Connect MongoDB and require authentication
This commit is contained in:
@@ -15,6 +15,7 @@ var (
|
||||
Environment: "production",
|
||||
Host: "127.0.0.1",
|
||||
Port: 3001,
|
||||
MongoDB: nil,
|
||||
Redis: nil,
|
||||
Cache: ConfigCache{
|
||||
EnableLocks: true,
|
||||
@@ -22,18 +23,17 @@ var (
|
||||
BedrockStatusDuration: time.Minute,
|
||||
IconDuration: time.Minute * 15,
|
||||
},
|
||||
AccessControl: ConfigAccessControl{},
|
||||
}
|
||||
)
|
||||
|
||||
// Config represents the application configuration.
|
||||
type Config struct {
|
||||
Environment string `yaml:"environment"`
|
||||
Host string `yaml:"host"`
|
||||
Port uint16 `yaml:"port"`
|
||||
Redis *string `yaml:"redis"`
|
||||
Cache ConfigCache `yaml:"cache"`
|
||||
AccessControl ConfigAccessControl `yaml:"access_control"`
|
||||
Environment string `yaml:"environment"`
|
||||
Host string `yaml:"host"`
|
||||
Port uint16 `yaml:"port"`
|
||||
MongoDB *string `yaml:"mongodb"`
|
||||
Redis *string `yaml:"redis"`
|
||||
Cache ConfigCache `yaml:"cache"`
|
||||
}
|
||||
|
||||
// ConfigCache represents the caching durations of various responses.
|
||||
@@ -44,12 +44,6 @@ type ConfigCache struct {
|
||||
IconDuration time.Duration `yaml:"icon_duration"`
|
||||
}
|
||||
|
||||
// ConfigAccessControl is the configuration for the CORS headers
|
||||
type ConfigAccessControl struct {
|
||||
Enable bool `yaml:"enable"`
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
}
|
||||
|
||||
// ReadFile reads the configuration from the given file and overrides values using environment variables.
|
||||
func (c *Config) ReadFile(file string) error {
|
||||
data, err := os.ReadFile(file)
|
||||
@@ -99,5 +93,9 @@ func (c *Config) overrideWithEnvVars() error {
|
||||
c.Redis = &value
|
||||
}
|
||||
|
||||
if value := os.Getenv("MONGO_URL"); value != "" {
|
||||
c.MongoDB = &value
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
16
src/main.go
16
src/main.go
@@ -25,9 +25,10 @@ var (
|
||||
return ctx.SendStatus(http.StatusInternalServerError)
|
||||
},
|
||||
})
|
||||
r *Redis = &Redis{}
|
||||
config *Config = DefaultConfig
|
||||
instanceID uint16 = 0
|
||||
r *Redis = &Redis{}
|
||||
db *MongoDB = &MongoDB{}
|
||||
config *Config = DefaultConfig
|
||||
instanceID uint16 = 0
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -51,6 +52,14 @@ func init() {
|
||||
|
||||
log.Println("Successfully retrieved EULA blocked servers")
|
||||
|
||||
if config.MongoDB != nil {
|
||||
if err = db.Connect(); err != nil {
|
||||
log.Fatalf("Failed to connect to MongoDB: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Successfully connected to MongoDB")
|
||||
}
|
||||
|
||||
if config.Redis != nil {
|
||||
if err = r.Connect(); err != nil {
|
||||
log.Fatalf("Failed to connect to Redis: %v", err)
|
||||
@@ -72,6 +81,7 @@ func init() {
|
||||
|
||||
func main() {
|
||||
defer r.Close()
|
||||
defer db.Close()
|
||||
|
||||
if err := app.Listen(fmt.Sprintf("%s:%d", config.Host, config.Port+instanceID)); err != nil {
|
||||
panic(err)
|
||||
|
||||
198
src/mongo.go
Normal file
198
src/mongo.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
var (
|
||||
CollectionApplications string = "applications"
|
||||
CollectionTokens string = "tokens"
|
||||
CollectionRequestLog string = "request_log"
|
||||
|
||||
ErrMongoNotConnected error = errors.New("cannot use method as MongoDB is not connected")
|
||||
)
|
||||
|
||||
type MongoDB struct {
|
||||
Client *mongo.Client
|
||||
Database *mongo.Database
|
||||
}
|
||||
|
||||
type Application struct {
|
||||
ID string `bson:"_id" json:"id"`
|
||||
Name string `bson:"name" json:"name"`
|
||||
ShortDescription string `bson:"shortDescription" json:"shortDescription"`
|
||||
User string `bson:"user" json:"user"`
|
||||
Token string `bson:"token" json:"token"`
|
||||
TotalRequests uint64 `bson:"totalRequests" json:"totalRequests"`
|
||||
CreatedAt time.Time `bson:"createdAt" json:"createdAt"`
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
ID string `bson:"_id" json:"id"`
|
||||
Name string `bson:"name" json:"name"`
|
||||
Token string `bson:"token" json:"token"`
|
||||
TotalRequests uint64 `bson:"totalRequests" json:"totalRequests"`
|
||||
Application string `bson:"application" json:"application"`
|
||||
CreatedAt time.Time `bson:"createdAt" json:"createdAt"`
|
||||
LastUsedAt time.Time `bson:"lastUsedAt" json:"lastUsedAt"`
|
||||
}
|
||||
|
||||
func (c *MongoDB) Connect() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
defer cancel()
|
||||
|
||||
parsedURI, err := url.Parse(*config.MongoDB)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := mongo.Connect(ctx, options.Client().ApplyURI(*config.MongoDB))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Client = client
|
||||
c.Database = client.Database(strings.TrimPrefix(parsedURI.Path, "/"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MongoDB) GetTokenByToken(token string) (*Token, error) {
|
||||
if c.Client == nil {
|
||||
return nil, ErrMongoNotConnected
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
defer cancel()
|
||||
|
||||
cur := c.Database.Collection(CollectionTokens).FindOne(ctx, bson.M{"token": token})
|
||||
|
||||
if err := cur.Err(); err != nil {
|
||||
if errors.Is(err, mongo.ErrNoDocuments) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result Token
|
||||
|
||||
if err := cur.Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *MongoDB) GetApplicationByID(id string) (*Application, error) {
|
||||
if c.Client == nil {
|
||||
return nil, ErrMongoNotConnected
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
defer cancel()
|
||||
|
||||
cur := c.Database.Collection(CollectionApplications).FindOne(ctx, bson.M{"_id": id})
|
||||
|
||||
if err := cur.Err(); err != nil {
|
||||
if errors.Is(err, mongo.ErrNoDocuments) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result Application
|
||||
|
||||
if err := cur.Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (c *MongoDB) UpdateToken(id string, update bson.M) error {
|
||||
if c.Client == nil {
|
||||
return ErrMongoNotConnected
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
defer cancel()
|
||||
|
||||
_, err := c.Database.Collection(CollectionTokens).UpdateOne(
|
||||
ctx,
|
||||
bson.M{"_id": id},
|
||||
update,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *MongoDB) IncrementApplicationRequestCount(id string) error {
|
||||
if c.Client == nil {
|
||||
return ErrMongoNotConnected
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
defer cancel()
|
||||
|
||||
_, err := c.Database.Collection(CollectionApplications).UpdateOne(
|
||||
ctx,
|
||||
bson.M{"_id": id},
|
||||
bson.M{
|
||||
"$inc": bson.M{
|
||||
"totalRequests": 1,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *MongoDB) UpsertRequestLog(query, update bson.M) error {
|
||||
if c.Client == nil {
|
||||
return ErrMongoNotConnected
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
defer cancel()
|
||||
|
||||
_, err := c.Database.Collection(CollectionRequestLog).UpdateOne(
|
||||
ctx,
|
||||
query,
|
||||
update,
|
||||
&options.UpdateOptions{
|
||||
Upsert: PointerOf(true),
|
||||
},
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *MongoDB) Close() error {
|
||||
if c.Client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
defer cancel()
|
||||
|
||||
return c.Client.Disconnect(ctx)
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"main/src/assets"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
@@ -26,15 +25,13 @@ func init() {
|
||||
Data: assets.Favicon,
|
||||
}))
|
||||
|
||||
if config.AccessControl.Enable {
|
||||
if config.Environment == "development" {
|
||||
app.Use(cors.New(cors.Config{
|
||||
AllowOrigins: strings.Join(config.AccessControl.AllowedOrigins, ","),
|
||||
AllowOrigins: "*",
|
||||
AllowMethods: "HEAD,OPTIONS,GET,POST",
|
||||
ExposeHeaders: "X-Cache-Hit,X-Cache-Time-Remaining",
|
||||
}))
|
||||
}
|
||||
|
||||
if config.Environment == "development" {
|
||||
app.Use(logger.New(logger.Config{
|
||||
Format: "${time} ${ip}:${port} -> ${status}: ${method} ${path} (${latency})\n",
|
||||
TimeFormat: "2006/01/02 15:04:05",
|
||||
@@ -68,6 +65,14 @@ func JavaStatusHandler(ctx *fiber.Ctx) error {
|
||||
return ctx.Status(http.StatusBadRequest).SendString("Invalid address value")
|
||||
}
|
||||
|
||||
authorized, err := Authenticate(ctx)
|
||||
|
||||
// This check should work for both scenarios, because nil should be returned if the user
|
||||
// is unauthorized, and err will be nil in that case.
|
||||
if err != nil || !authorized {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = r.Increment(fmt.Sprintf("java-hits:%s-%d", host, port)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
83
src/util.go
83
src/util.go
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
_ "embed"
|
||||
"encoding/hex"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -248,6 +250,7 @@ func GetVoteOptions(ctx *fiber.Ctx) (*VoteOptions, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// GetStatusOptions returns the options for status routes, with the default values filled in.
|
||||
func GetStatusOptions(ctx *fiber.Ctx) (*StatusOptions, error) {
|
||||
result := &StatusOptions{}
|
||||
|
||||
@@ -292,6 +295,70 @@ func GetCacheKey(host string, port uint16, opts *StatusOptions) string {
|
||||
return SHA256(values.Encode())
|
||||
}
|
||||
|
||||
// Authenticate checks and requires authentication for the current request, by finding the token.
|
||||
func Authenticate(ctx *fiber.Ctx) (bool, error) {
|
||||
if config.MongoDB == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
authToken := ctx.Get("Authorization")
|
||||
|
||||
if len(authToken) < 1 {
|
||||
if err := ctx.Status(http.StatusUnauthorized).SendString("Missing 'Authorization' header in request"); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
token, err := db.GetTokenByToken(authToken)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if token == nil {
|
||||
if err := ctx.Status(http.StatusUnauthorized).SendString("Invalid or expired authorization token, please generate another one in the dashboard"); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err = db.IncrementApplicationRequestCount(token.Application); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err = db.UpdateToken(
|
||||
token.ID,
|
||||
bson.M{
|
||||
"$inc": bson.M{"requestCount": 1},
|
||||
"$set": bson.M{"lastUsedAt": time.Now().UTC()},
|
||||
},
|
||||
); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err = db.UpsertRequestLog(
|
||||
bson.M{
|
||||
"application": token.Application,
|
||||
"timestamp": GetStartOfHour(),
|
||||
},
|
||||
bson.M{
|
||||
"$setOnInsert": bson.M{
|
||||
"_id": RandomHexString(16),
|
||||
},
|
||||
"$inc": bson.M{
|
||||
"requestCount": 1,
|
||||
},
|
||||
},
|
||||
); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// SHA256 returns the result of hashing the input value using SHA256 algorithm.
|
||||
func SHA256(input string) string {
|
||||
result := sha1.Sum([]byte(input))
|
||||
@@ -325,3 +392,19 @@ func Map[I, O any](arr []I, f func(I) O) []O {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetStartOfHour returns the current date and time rounded down to the start of the hour.
|
||||
func GetStartOfHour() time.Time {
|
||||
return time.Now().UTC().Truncate(time.Hour)
|
||||
}
|
||||
|
||||
// RandomHexString returns a random hexadecimal string with the specified byte length.
|
||||
func RandomHexString(byteLength int) string {
|
||||
data := make([]byte, byteLength)
|
||||
|
||||
if _, err := rand.Read(data); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(data)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user