diff --git a/src/config.go b/src/config.go index ec1205a..9641376 100644 --- a/src/config.go +++ b/src/config.go @@ -1,7 +1,7 @@ package main import ( - "io/ioutil" + "os" "time" "gopkg.in/yaml.v3" @@ -20,7 +20,7 @@ type Config struct { } func (c *Config) ReadFile(file string) error { - data, err := ioutil.ReadFile(file) + data, err := os.ReadFile(file) if err != nil { return err diff --git a/src/main.go b/src/main.go index 26db9da..94065f5 100644 --- a/src/main.go +++ b/src/main.go @@ -4,9 +4,6 @@ import ( "fmt" "log" "net/http" - "os" - "strconv" - "sync" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" @@ -14,15 +11,7 @@ import ( ) var ( - app *fiber.App = nil - r *Redis = &Redis{} - config *Config = &Config{} - blockedServers []string = nil - blockedServersMutex *sync.Mutex = &sync.Mutex{} -) - -func init() { - app = fiber.New(fiber.Config{ + app *fiber.App = fiber.New(fiber.Config{ DisableStartupMessage: true, ErrorHandler: func(ctx *fiber.Ctx, err error) error { log.Println(err) @@ -30,48 +19,44 @@ func init() { return ctx.SendStatus(http.StatusInternalServerError) }, }) + r *Redis = &Redis{} + config *Config = &Config{} + blockedServers *MutexArray[string] = nil +) +func init() { if err := config.ReadFile("config.yml"); err != nil { log.Fatal(err) } - r.SetEnabled(config.Cache.Enable) - - if config.Cache.Enable { - if err := r.Connect(config.Redis); err != nil { - log.Fatal(err) - } - - log.Println("Successfully connected to Redis") + if err := r.Connect(config.Redis); err != nil { + log.Fatal(err) } + log.Println("Successfully connected to Redis") + if err := GetBlockedServerList(); err != nil { log.Fatal(err) } log.Println("Successfully retrieved EULA blocked servers") - if instanceID := os.Getenv("INSTANCE_ID"); len(instanceID) > 0 { - value, err := strconv.ParseUint(instanceID, 10, 16) - - if err != nil { - log.Fatal(err) - } - - config.Port += uint16(value) - } - + app.Config() + app.Use(recover.New()) app.Use(cors.New(cors.Config{ AllowOrigins: "*", AllowMethods: "HEAD,OPTIONS,GET", ExposeHeaders: "Content-Type,X-Cache-Time-Remaining", })) - - app.Use(recover.New()) } func main() { - log.Printf("Listening on %s:%d\n", config.Host, config.Port) + instanceID, err := GetInstanceID() - log.Fatal(app.Listen(fmt.Sprintf("%s:%d", config.Host, config.Port))) + if err != nil { + log.Fatal(err) + } + + log.Printf("Listening on %s:%d\n", config.Host, config.Port+instanceID) + log.Fatal(app.Listen(fmt.Sprintf("%s:%d", config.Host, config.Port+instanceID))) } diff --git a/src/redis.go b/src/redis.go index e1cc953..2ef6f40 100644 --- a/src/redis.go +++ b/src/redis.go @@ -8,16 +8,11 @@ import ( ) type Redis struct { - Enabled bool - Client *redis.Client -} - -func (r *Redis) SetEnabled(value bool) { - r.Enabled = value + Client *redis.Client } func (r *Redis) Connect(uri string) error { - if !r.Enabled { + if !config.Cache.Enable { return nil } @@ -37,7 +32,7 @@ func (r *Redis) Connect(uri string) error { } func (r *Redis) Exists(key string) (bool, error) { - if !r.Enabled { + if !config.Cache.Enable { return false, nil } @@ -57,7 +52,7 @@ func (r *Redis) Exists(key string) (bool, error) { } func (r *Redis) TTL(key string) (time.Duration, error) { - if !r.Enabled { + if !config.Cache.Enable { return 0, nil } @@ -75,7 +70,7 @@ func (r *Redis) TTL(key string) (time.Duration, error) { } func (r *Redis) GetString(key string) (string, error) { - if !r.Enabled { + if !config.Cache.Enable { return "", nil } @@ -93,7 +88,7 @@ func (r *Redis) GetString(key string) (string, error) { } func (r *Redis) GetBytes(key string) ([]byte, error) { - if !r.Enabled { + if !config.Cache.Enable { return nil, nil } @@ -111,7 +106,7 @@ func (r *Redis) GetBytes(key string) ([]byte, error) { } func (r *Redis) Set(key string, value interface{}, ttl time.Duration) error { - if !r.Enabled { + if !config.Cache.Enable { return nil } @@ -123,7 +118,7 @@ func (r *Redis) Set(key string, value interface{}, ttl time.Duration) error { } func (r *Redis) Close() error { - if !r.Enabled { + if !config.Cache.Enable { return nil } diff --git a/src/util.go b/src/util.go index 9e7469e..baa3127 100644 --- a/src/util.go +++ b/src/util.go @@ -5,11 +5,14 @@ import ( _ "embed" "encoding/hex" "fmt" - "io/ioutil" + "io" + "log" "net/http" + "os" "regexp" "strconv" "strings" + "sync" ) var ( @@ -18,16 +21,6 @@ var ( ipAddressRegExp = regexp.MustCompile(`^\d{1,3}(\.\d{1,3}){3}$`) ) -func Contains[T comparable](arr []T, v T) bool { - for _, value := range arr { - if v == value { - return true - } - } - - return false -} - func GetBlockedServerList() error { resp, err := http.Get("https://sessionserver.mojang.com/blockedservers") @@ -41,15 +34,16 @@ func GetBlockedServerList() error { defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return err } - blockedServersMutex.Lock() - blockedServers = strings.Split(string(body), "\n") - blockedServersMutex.Unlock() + blockedServers = &MutexArray[string]{ + List: strings.Split(string(body), "\n"), + Mutex: &sync.Mutex{}, + } return nil } @@ -84,15 +78,9 @@ func IsBlockedAddress(address string) bool { newAddressBytes := sha1.Sum([]byte(newAddress)) newAddressHash := hex.EncodeToString(newAddressBytes[:]) - blockedServersMutex.Lock() - - if Contains(blockedServers, newAddressHash) { - blockedServersMutex.Unlock() - + if blockedServers.Has(newAddressHash) { return true } - - blockedServersMutex.Unlock() } return false @@ -117,3 +105,36 @@ func ParseAddress(address string, defaultPort uint16) (string, uint16, error) { return result[0], uint16(port), nil } + +func GetInstanceID() (uint16, error) { + if instanceID := os.Getenv("INSTANCE_ID"); len(instanceID) > 0 { + value, err := strconv.ParseUint(instanceID, 10, 16) + + if err != nil { + log.Fatal(err) + } + + return uint16(value), nil + } + + return 0, nil +} + +type MutexArray[K comparable] struct { + List []K + Mutex *sync.Mutex +} + +func (m *MutexArray[K]) Has(value K) bool { + m.Mutex.Lock() + + defer m.Mutex.Unlock() + + for _, v := range m.List { + if v == value { + return true + } + } + + return false +}