Use Redis distributed locks
This commit is contained in:
34
src/main.go
34
src/main.go
@@ -2,10 +2,11 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
@@ -22,12 +23,15 @@ var (
|
||||
return ctx.SendStatus(http.StatusInternalServerError)
|
||||
},
|
||||
})
|
||||
r *Redis = &Redis{}
|
||||
conf *Config = DefaultConfig
|
||||
r *Redis = &Redis{}
|
||||
conf *Config = DefaultConfig
|
||||
instanceID uint16
|
||||
)
|
||||
|
||||
func init() {
|
||||
if err := conf.ReadFile("config.yml"); err != nil {
|
||||
var err error
|
||||
|
||||
if err = conf.ReadFile("config.yml"); err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
log.Printf("config.yml does not exist, writing default config\n")
|
||||
|
||||
@@ -39,14 +43,14 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
if err := GetBlockedServerList(); err != nil {
|
||||
if err = GetBlockedServerList(); err != nil {
|
||||
log.Fatalf("Failed to retrieve EULA blocked servers: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Successfully retrieved EULA blocked servers")
|
||||
|
||||
if conf.Redis != nil {
|
||||
if err := r.Connect(); err != nil {
|
||||
if err = r.Connect(); err != nil {
|
||||
log.Fatalf("Failed to connect to Redis: %v", err)
|
||||
}
|
||||
|
||||
@@ -69,20 +73,20 @@ func init() {
|
||||
TimeFormat: "2006/01/02 15:04:05",
|
||||
}))
|
||||
}
|
||||
|
||||
if instanceID, err = GetInstanceID(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
defer r.Close()
|
||||
|
||||
instanceID, err := GetInstanceID()
|
||||
go ListenAndServe(conf.Host, conf.Port+instanceID)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer app.Shutdown()
|
||||
|
||||
log.Printf("Listening on %s:%d\n", conf.Host, conf.Port+instanceID)
|
||||
|
||||
if err := app.Listen(fmt.Sprintf("%s:%d", conf.Host, conf.Port+instanceID)); err != nil {
|
||||
log.Fatalf("failed to start server: %v", err)
|
||||
}
|
||||
s := make(chan os.Signal, 1)
|
||||
signal.Notify(s, os.Interrupt, syscall.SIGTERM)
|
||||
<-s
|
||||
}
|
||||
|
||||
65
src/redis.go
65
src/redis.go
@@ -5,14 +5,19 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/go-redsync/redsync/v4"
|
||||
redsyncredis "github.com/go-redsync/redsync/v4/redis"
|
||||
redsyncredislib "github.com/go-redsync/redsync/v4/redis/goredis/v9"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const defaultTimeout = 5 * time.Second
|
||||
|
||||
// Redis is a wrapper around the Redis client.
|
||||
type Redis struct {
|
||||
Client *redis.Client
|
||||
Client *redis.Client
|
||||
Pool *redsyncredis.Pool
|
||||
SyncClient *redsync.Redsync
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the Redis server using the configuration.
|
||||
@@ -21,6 +26,10 @@ func (r *Redis) Connect() error {
|
||||
return errors.New("missing Redis configuration")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
||||
|
||||
defer cancel()
|
||||
|
||||
opts, err := redis.ParseURL(*conf.Redis)
|
||||
|
||||
if err != nil {
|
||||
@@ -29,11 +38,15 @@ func (r *Redis) Connect() error {
|
||||
|
||||
r.Client = redis.NewClient(opts)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
||||
if err = r.Client.Ping(ctx).Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer cancel()
|
||||
pool := redsyncredislib.NewPool(r.Client)
|
||||
|
||||
return r.Client.Ping(ctx).Err()
|
||||
r.SyncClient = redsync.New(pool)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves the value and TTL for a given key.
|
||||
@@ -90,6 +103,19 @@ func (r *Redis) Increment(key string) error {
|
||||
return r.Client.Incr(ctx, key).Err()
|
||||
}
|
||||
|
||||
// NewMutex creates a new mutually exclusive lock that only one process can hold.
|
||||
func (r *Redis) NewMutex(name string) *Mutex {
|
||||
if r.Client == nil || r.SyncClient == nil {
|
||||
return &Mutex{
|
||||
Mutex: nil,
|
||||
}
|
||||
}
|
||||
|
||||
return &Mutex{
|
||||
Mutex: r.SyncClient.NewMutex(name),
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the Redis client connection.
|
||||
func (r *Redis) Close() error {
|
||||
if r.Client == nil {
|
||||
@@ -98,3 +124,32 @@ func (r *Redis) Close() error {
|
||||
|
||||
return r.Client.Close()
|
||||
}
|
||||
|
||||
// Mutex is a mutually exclusive lock held across all processes.
|
||||
type Mutex struct {
|
||||
Mutex *redsync.Mutex
|
||||
}
|
||||
|
||||
// Lock will lock the mutex so no other process can hold it.
|
||||
func (m *Mutex) Lock() error {
|
||||
if m.Mutex == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
|
||||
defer cancel()
|
||||
|
||||
return m.Mutex.LockContext(ctx)
|
||||
}
|
||||
|
||||
// Unlock will allow any other process to obtain a lock with the same key.
|
||||
func (m *Mutex) Unlock() error {
|
||||
if m.Mutex == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := m.Mutex.Unlock()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -30,13 +30,11 @@ func JavaStatusHandler(ctx *fiber.Ctx) error {
|
||||
return ctx.Status(http.StatusBadRequest).SendString("Invalid address value")
|
||||
}
|
||||
|
||||
enableQuery := ctx.QueryBool("query", true)
|
||||
|
||||
if err = r.Increment(fmt.Sprintf("java-hits:%s-%d", host, port)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, expiresAt, err := GetJavaStatus(host, port, enableQuery)
|
||||
response, expiresAt, err := GetJavaStatus(host, port, ctx.QueryBool("query", true))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
14
src/server.go
Normal file
14
src/server.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
func ListenAndServe(host string, port uint16) {
|
||||
log.Printf("Listening on %s:%d\n", host, port)
|
||||
|
||||
if err := app.Listen(fmt.Sprintf("%s:%d", host, port)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -116,6 +116,11 @@ type Plugin struct {
|
||||
func GetJavaStatus(host string, port uint16, enableQuery bool) (*JavaStatusResponse, time.Duration, error) {
|
||||
cacheKey := fmt.Sprintf("java:%v-%s-%d", enableQuery, host, port)
|
||||
|
||||
mutex := r.NewMutex(fmt.Sprintf("java-lock:%v-%s-%d", enableQuery, host, port))
|
||||
mutex.Lock()
|
||||
|
||||
defer mutex.Unlock()
|
||||
|
||||
cache, ttl, err := r.Get(cacheKey)
|
||||
|
||||
if err != nil {
|
||||
@@ -149,6 +154,11 @@ func GetJavaStatus(host string, port uint16, enableQuery bool) (*JavaStatusRespo
|
||||
func GetBedrockStatus(host string, port uint16) (*BedrockStatusResponse, time.Duration, error) {
|
||||
cacheKey := fmt.Sprintf("bedrock:%s-%d", host, port)
|
||||
|
||||
mutex := r.NewMutex(fmt.Sprintf("bedrock-lock:%s-%d", host, port))
|
||||
mutex.Lock()
|
||||
|
||||
defer mutex.Unlock()
|
||||
|
||||
cache, ttl, err := r.Get(cacheKey)
|
||||
|
||||
if err != nil {
|
||||
|
||||
57
src/util.go
57
src/util.go
@@ -18,19 +18,20 @@ import (
|
||||
var (
|
||||
//go:embed icon.png
|
||||
defaultIconBytes []byte
|
||||
blockedServers *MutexArray = nil
|
||||
ipAddressRegex *regexp.Regexp = regexp.MustCompile(`^\d{1,3}(\.\d{1,3}){3}$`)
|
||||
blockedServers *MutexArray[string] = nil
|
||||
ipAddressRegex *regexp.Regexp = regexp.MustCompile(`^\d{1,3}(\.\d{1,3}){3}$`)
|
||||
)
|
||||
|
||||
// MutexArray is a thread-safe array for storing and checking values.
|
||||
type MutexArray struct {
|
||||
List []interface{}
|
||||
// MutexArray is a thread-safe array for storing and retrieving values.
|
||||
type MutexArray[T comparable] struct {
|
||||
List []T
|
||||
Mutex *sync.Mutex
|
||||
}
|
||||
|
||||
// Has checks if the given value is present in the array.
|
||||
func (m *MutexArray) Has(value interface{}) bool {
|
||||
func (m *MutexArray[T]) Has(value T) bool {
|
||||
m.Mutex.Lock()
|
||||
|
||||
defer m.Mutex.Unlock()
|
||||
|
||||
for _, v := range m.List {
|
||||
@@ -51,7 +52,7 @@ func GetBlockedServerList() error {
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
return fmt.Errorf("mojang: unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
@@ -62,16 +63,8 @@ func GetBlockedServerList() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert []string to []interface{}
|
||||
strSlice := strings.Split(string(body), "\n")
|
||||
interfaceSlice := make([]interface{}, len(strSlice))
|
||||
|
||||
for i, v := range strSlice {
|
||||
interfaceSlice[i] = v
|
||||
}
|
||||
|
||||
blockedServers = &MutexArray{
|
||||
List: interfaceSlice,
|
||||
blockedServers = &MutexArray[string]{
|
||||
List: strings.Split(string(body), "\n"),
|
||||
Mutex: &sync.Mutex{},
|
||||
}
|
||||
|
||||
@@ -80,24 +73,21 @@ func GetBlockedServerList() error {
|
||||
|
||||
// IsBlockedAddress checks if the given address is in the blocked servers list.
|
||||
func IsBlockedAddress(address string) bool {
|
||||
split := strings.Split(strings.ToLower(address), ".")
|
||||
isIPAddress := ipAddressRegex.MatchString(address)
|
||||
addressSegments := strings.Split(strings.ToLower(address), ".")
|
||||
isIPv4Address := ipAddressRegex.MatchString(address)
|
||||
|
||||
for k := range split {
|
||||
var newAddress string
|
||||
for i := range addressSegments {
|
||||
var checkAddress string
|
||||
|
||||
if k == 0 {
|
||||
newAddress = strings.Join(split, ".")
|
||||
} else if isIPAddress {
|
||||
newAddress = fmt.Sprintf("%s.*", strings.Join(split[0:len(split)-k], "."))
|
||||
if i == 0 {
|
||||
checkAddress = strings.Join(addressSegments, ".")
|
||||
} else if isIPv4Address {
|
||||
checkAddress = fmt.Sprintf("%s.*", strings.Join(addressSegments[0:len(addressSegments)-i], "."))
|
||||
} else {
|
||||
newAddress = fmt.Sprintf("*.%s", strings.Join(split[k:], "."))
|
||||
checkAddress = fmt.Sprintf("*.%s", strings.Join(addressSegments[i:], "."))
|
||||
}
|
||||
|
||||
newAddressBytes := sha1.Sum([]byte(newAddress))
|
||||
newAddressHash := hex.EncodeToString(newAddressBytes[:])
|
||||
|
||||
if blockedServers.Has(newAddressHash) {
|
||||
if blockedServers.Has(SHA256(checkAddress)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -141,6 +131,13 @@ func GetInstanceID() (uint16, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// SHA256 returns the result of hashing the input value using SHA256 algorithm.
|
||||
func SHA256(input string) string {
|
||||
result := sha1.Sum([]byte(input))
|
||||
|
||||
return hex.EncodeToString(result[:])
|
||||
}
|
||||
|
||||
// PointerOf returns a pointer of the argument passed.
|
||||
func PointerOf[T any](v T) *T {
|
||||
return &v
|
||||
|
||||
Reference in New Issue
Block a user