Code cleanup

This commit is contained in:
Jacob Gunther
2022-08-24 03:48:09 -05:00
parent d8ee96413a
commit 1d23a97732
4 changed files with 72 additions and 71 deletions

View File

@@ -1,7 +1,7 @@
package main package main
import ( import (
"io/ioutil" "os"
"time" "time"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -20,7 +20,7 @@ type Config struct {
} }
func (c *Config) ReadFile(file string) error { func (c *Config) ReadFile(file string) error {
data, err := ioutil.ReadFile(file) data, err := os.ReadFile(file)
if err != nil { if err != nil {
return err return err

View File

@@ -4,9 +4,6 @@ import (
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"os"
"strconv"
"sync"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/cors"
@@ -14,15 +11,7 @@ import (
) )
var ( var (
app *fiber.App = nil app *fiber.App = fiber.New(fiber.Config{
r *Redis = &Redis{}
config *Config = &Config{}
blockedServers []string = nil
blockedServersMutex *sync.Mutex = &sync.Mutex{}
)
func init() {
app = fiber.New(fiber.Config{
DisableStartupMessage: true, DisableStartupMessage: true,
ErrorHandler: func(ctx *fiber.Ctx, err error) error { ErrorHandler: func(ctx *fiber.Ctx, err error) error {
log.Println(err) log.Println(err)
@@ -30,48 +19,44 @@ func init() {
return ctx.SendStatus(http.StatusInternalServerError) return ctx.SendStatus(http.StatusInternalServerError)
}, },
}) })
r *Redis = &Redis{}
config *Config = &Config{}
blockedServers *MutexArray[string] = nil
)
func init() {
if err := config.ReadFile("config.yml"); err != nil { if err := config.ReadFile("config.yml"); err != nil {
log.Fatal(err) log.Fatal(err)
} }
r.SetEnabled(config.Cache.Enable) if err := r.Connect(config.Redis); err != nil {
log.Fatal(err)
if config.Cache.Enable {
if err := r.Connect(config.Redis); err != nil {
log.Fatal(err)
}
log.Println("Successfully connected to Redis")
} }
log.Println("Successfully connected to Redis")
if err := GetBlockedServerList(); err != nil { if err := GetBlockedServerList(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
log.Println("Successfully retrieved EULA blocked servers") log.Println("Successfully retrieved EULA blocked servers")
if instanceID := os.Getenv("INSTANCE_ID"); len(instanceID) > 0 { app.Config()
value, err := strconv.ParseUint(instanceID, 10, 16) app.Use(recover.New())
if err != nil {
log.Fatal(err)
}
config.Port += uint16(value)
}
app.Use(cors.New(cors.Config{ app.Use(cors.New(cors.Config{
AllowOrigins: "*", AllowOrigins: "*",
AllowMethods: "HEAD,OPTIONS,GET", AllowMethods: "HEAD,OPTIONS,GET",
ExposeHeaders: "Content-Type,X-Cache-Time-Remaining", ExposeHeaders: "Content-Type,X-Cache-Time-Remaining",
})) }))
app.Use(recover.New())
} }
func main() { func main() {
log.Printf("Listening on %s:%d\n", config.Host, config.Port) instanceID, err := GetInstanceID()
log.Fatal(app.Listen(fmt.Sprintf("%s:%d", config.Host, config.Port))) if err != nil {
log.Fatal(err)
}
log.Printf("Listening on %s:%d\n", config.Host, config.Port+instanceID)
log.Fatal(app.Listen(fmt.Sprintf("%s:%d", config.Host, config.Port+instanceID)))
} }

View File

@@ -8,16 +8,11 @@ import (
) )
type Redis struct { type Redis struct {
Enabled bool Client *redis.Client
Client *redis.Client
}
func (r *Redis) SetEnabled(value bool) {
r.Enabled = value
} }
func (r *Redis) Connect(uri string) error { func (r *Redis) Connect(uri string) error {
if !r.Enabled { if !config.Cache.Enable {
return nil return nil
} }
@@ -37,7 +32,7 @@ func (r *Redis) Connect(uri string) error {
} }
func (r *Redis) Exists(key string) (bool, error) { func (r *Redis) Exists(key string) (bool, error) {
if !r.Enabled { if !config.Cache.Enable {
return false, nil return false, nil
} }
@@ -57,7 +52,7 @@ func (r *Redis) Exists(key string) (bool, error) {
} }
func (r *Redis) TTL(key string) (time.Duration, error) { func (r *Redis) TTL(key string) (time.Duration, error) {
if !r.Enabled { if !config.Cache.Enable {
return 0, nil return 0, nil
} }
@@ -75,7 +70,7 @@ func (r *Redis) TTL(key string) (time.Duration, error) {
} }
func (r *Redis) GetString(key string) (string, error) { func (r *Redis) GetString(key string) (string, error) {
if !r.Enabled { if !config.Cache.Enable {
return "", nil return "", nil
} }
@@ -93,7 +88,7 @@ func (r *Redis) GetString(key string) (string, error) {
} }
func (r *Redis) GetBytes(key string) ([]byte, error) { func (r *Redis) GetBytes(key string) ([]byte, error) {
if !r.Enabled { if !config.Cache.Enable {
return nil, nil return nil, nil
} }
@@ -111,7 +106,7 @@ func (r *Redis) GetBytes(key string) ([]byte, error) {
} }
func (r *Redis) Set(key string, value interface{}, ttl time.Duration) error { func (r *Redis) Set(key string, value interface{}, ttl time.Duration) error {
if !r.Enabled { if !config.Cache.Enable {
return nil return nil
} }
@@ -123,7 +118,7 @@ func (r *Redis) Set(key string, value interface{}, ttl time.Duration) error {
} }
func (r *Redis) Close() error { func (r *Redis) Close() error {
if !r.Enabled { if !config.Cache.Enable {
return nil return nil
} }

View File

@@ -5,11 +5,14 @@ import (
_ "embed" _ "embed"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io/ioutil" "io"
"log"
"net/http" "net/http"
"os"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"sync"
) )
var ( var (
@@ -18,16 +21,6 @@ var (
ipAddressRegExp = regexp.MustCompile(`^\d{1,3}(\.\d{1,3}){3}$`) ipAddressRegExp = regexp.MustCompile(`^\d{1,3}(\.\d{1,3}){3}$`)
) )
func Contains[T comparable](arr []T, v T) bool {
for _, value := range arr {
if v == value {
return true
}
}
return false
}
func GetBlockedServerList() error { func GetBlockedServerList() error {
resp, err := http.Get("https://sessionserver.mojang.com/blockedservers") resp, err := http.Get("https://sessionserver.mojang.com/blockedservers")
@@ -41,15 +34,16 @@ func GetBlockedServerList() error {
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err return err
} }
blockedServersMutex.Lock() blockedServers = &MutexArray[string]{
blockedServers = strings.Split(string(body), "\n") List: strings.Split(string(body), "\n"),
blockedServersMutex.Unlock() Mutex: &sync.Mutex{},
}
return nil return nil
} }
@@ -84,15 +78,9 @@ func IsBlockedAddress(address string) bool {
newAddressBytes := sha1.Sum([]byte(newAddress)) newAddressBytes := sha1.Sum([]byte(newAddress))
newAddressHash := hex.EncodeToString(newAddressBytes[:]) newAddressHash := hex.EncodeToString(newAddressBytes[:])
blockedServersMutex.Lock() if blockedServers.Has(newAddressHash) {
if Contains(blockedServers, newAddressHash) {
blockedServersMutex.Unlock()
return true return true
} }
blockedServersMutex.Unlock()
} }
return false return false
@@ -117,3 +105,36 @@ func ParseAddress(address string, defaultPort uint16) (string, uint16, error) {
return result[0], uint16(port), nil return result[0], uint16(port), nil
} }
func GetInstanceID() (uint16, error) {
if instanceID := os.Getenv("INSTANCE_ID"); len(instanceID) > 0 {
value, err := strconv.ParseUint(instanceID, 10, 16)
if err != nil {
log.Fatal(err)
}
return uint16(value), nil
}
return 0, nil
}
type MutexArray[K comparable] struct {
List []K
Mutex *sync.Mutex
}
func (m *MutexArray[K]) Has(value K) bool {
m.Mutex.Lock()
defer m.Mutex.Unlock()
for _, v := range m.List {
if v == value {
return true
}
}
return false
}