Use Redis distributed locks

This commit is contained in:
Jacob Gunther
2023-07-19 21:08:58 -05:00
parent 2f46ca5611
commit 8554fd2064
8 changed files with 251 additions and 53 deletions

View File

@@ -2,10 +2,11 @@ package main
import (
"errors"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
@@ -22,12 +23,15 @@ var (
return ctx.SendStatus(http.StatusInternalServerError)
},
})
r *Redis = &Redis{}
conf *Config = DefaultConfig
r *Redis = &Redis{}
conf *Config = DefaultConfig
instanceID uint16
)
func init() {
if err := conf.ReadFile("config.yml"); err != nil {
var err error
if err = conf.ReadFile("config.yml"); err != nil {
if errors.Is(err, os.ErrNotExist) {
log.Printf("config.yml does not exist, writing default config\n")
@@ -39,14 +43,14 @@ func init() {
}
}
if err := GetBlockedServerList(); err != nil {
if err = GetBlockedServerList(); err != nil {
log.Fatalf("Failed to retrieve EULA blocked servers: %v", err)
}
log.Println("Successfully retrieved EULA blocked servers")
if conf.Redis != nil {
if err := r.Connect(); err != nil {
if err = r.Connect(); err != nil {
log.Fatalf("Failed to connect to Redis: %v", err)
}
@@ -69,20 +73,20 @@ func init() {
TimeFormat: "2006/01/02 15:04:05",
}))
}
if instanceID, err = GetInstanceID(); err != nil {
panic(err)
}
}
func main() {
defer r.Close()
instanceID, err := GetInstanceID()
go ListenAndServe(conf.Host, conf.Port+instanceID)
if err != nil {
log.Fatal(err)
}
defer app.Shutdown()
log.Printf("Listening on %s:%d\n", conf.Host, conf.Port+instanceID)
if err := app.Listen(fmt.Sprintf("%s:%d", conf.Host, conf.Port+instanceID)); err != nil {
log.Fatalf("failed to start server: %v", err)
}
s := make(chan os.Signal, 1)
signal.Notify(s, os.Interrupt, syscall.SIGTERM)
<-s
}

View File

@@ -5,14 +5,19 @@ import (
"errors"
"time"
"github.com/go-redis/redis/v8"
"github.com/go-redsync/redsync/v4"
redsyncredis "github.com/go-redsync/redsync/v4/redis"
redsyncredislib "github.com/go-redsync/redsync/v4/redis/goredis/v9"
"github.com/redis/go-redis/v9"
)
const defaultTimeout = 5 * time.Second
// Redis is a wrapper around the Redis client.
type Redis struct {
Client *redis.Client
Client *redis.Client
Pool *redsyncredis.Pool
SyncClient *redsync.Redsync
}
// Connect establishes a connection to the Redis server using the configuration.
@@ -21,6 +26,10 @@ func (r *Redis) Connect() error {
return errors.New("missing Redis configuration")
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
opts, err := redis.ParseURL(*conf.Redis)
if err != nil {
@@ -29,11 +38,15 @@ func (r *Redis) Connect() error {
r.Client = redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
if err = r.Client.Ping(ctx).Err(); err != nil {
return err
}
defer cancel()
pool := redsyncredislib.NewPool(r.Client)
return r.Client.Ping(ctx).Err()
r.SyncClient = redsync.New(pool)
return nil
}
// Get retrieves the value and TTL for a given key.
@@ -90,6 +103,19 @@ func (r *Redis) Increment(key string) error {
return r.Client.Incr(ctx, key).Err()
}
// NewMutex creates a new mutually exclusive lock that only one process can hold.
func (r *Redis) NewMutex(name string) *Mutex {
if r.Client == nil || r.SyncClient == nil {
return &Mutex{
Mutex: nil,
}
}
return &Mutex{
Mutex: r.SyncClient.NewMutex(name),
}
}
// Close closes the Redis client connection.
func (r *Redis) Close() error {
if r.Client == nil {
@@ -98,3 +124,32 @@ func (r *Redis) Close() error {
return r.Client.Close()
}
// Mutex is a mutually exclusive lock held across all processes.
type Mutex struct {
Mutex *redsync.Mutex
}
// Lock will lock the mutex so no other process can hold it.
func (m *Mutex) Lock() error {
if m.Mutex == nil {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
return m.Mutex.LockContext(ctx)
}
// Unlock will allow any other process to obtain a lock with the same key.
func (m *Mutex) Unlock() error {
if m.Mutex == nil {
return nil
}
_, err := m.Mutex.Unlock()
return err
}

View File

@@ -30,13 +30,11 @@ func JavaStatusHandler(ctx *fiber.Ctx) error {
return ctx.Status(http.StatusBadRequest).SendString("Invalid address value")
}
enableQuery := ctx.QueryBool("query", true)
if err = r.Increment(fmt.Sprintf("java-hits:%s-%d", host, port)); err != nil {
return err
}
response, expiresAt, err := GetJavaStatus(host, port, enableQuery)
response, expiresAt, err := GetJavaStatus(host, port, ctx.QueryBool("query", true))
if err != nil {
return err

14
src/server.go Normal file
View File

@@ -0,0 +1,14 @@
package main
import (
"fmt"
"log"
)
func ListenAndServe(host string, port uint16) {
log.Printf("Listening on %s:%d\n", host, port)
if err := app.Listen(fmt.Sprintf("%s:%d", host, port)); err != nil {
panic(err)
}
}

View File

@@ -116,6 +116,11 @@ type Plugin struct {
func GetJavaStatus(host string, port uint16, enableQuery bool) (*JavaStatusResponse, time.Duration, error) {
cacheKey := fmt.Sprintf("java:%v-%s-%d", enableQuery, host, port)
mutex := r.NewMutex(fmt.Sprintf("java-lock:%v-%s-%d", enableQuery, host, port))
mutex.Lock()
defer mutex.Unlock()
cache, ttl, err := r.Get(cacheKey)
if err != nil {
@@ -149,6 +154,11 @@ func GetJavaStatus(host string, port uint16, enableQuery bool) (*JavaStatusRespo
func GetBedrockStatus(host string, port uint16) (*BedrockStatusResponse, time.Duration, error) {
cacheKey := fmt.Sprintf("bedrock:%s-%d", host, port)
mutex := r.NewMutex(fmt.Sprintf("bedrock-lock:%s-%d", host, port))
mutex.Lock()
defer mutex.Unlock()
cache, ttl, err := r.Get(cacheKey)
if err != nil {

View File

@@ -18,19 +18,20 @@ import (
var (
//go:embed icon.png
defaultIconBytes []byte
blockedServers *MutexArray = nil
ipAddressRegex *regexp.Regexp = regexp.MustCompile(`^\d{1,3}(\.\d{1,3}){3}$`)
blockedServers *MutexArray[string] = nil
ipAddressRegex *regexp.Regexp = regexp.MustCompile(`^\d{1,3}(\.\d{1,3}){3}$`)
)
// MutexArray is a thread-safe array for storing and checking values.
type MutexArray struct {
List []interface{}
// MutexArray is a thread-safe array for storing and retrieving values.
type MutexArray[T comparable] struct {
List []T
Mutex *sync.Mutex
}
// Has checks if the given value is present in the array.
func (m *MutexArray) Has(value interface{}) bool {
func (m *MutexArray[T]) Has(value T) bool {
m.Mutex.Lock()
defer m.Mutex.Unlock()
for _, v := range m.List {
@@ -51,7 +52,7 @@ func GetBlockedServerList() error {
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
return fmt.Errorf("mojang: unexpected status code: %d", resp.StatusCode)
}
defer resp.Body.Close()
@@ -62,16 +63,8 @@ func GetBlockedServerList() error {
return err
}
// Convert []string to []interface{}
strSlice := strings.Split(string(body), "\n")
interfaceSlice := make([]interface{}, len(strSlice))
for i, v := range strSlice {
interfaceSlice[i] = v
}
blockedServers = &MutexArray{
List: interfaceSlice,
blockedServers = &MutexArray[string]{
List: strings.Split(string(body), "\n"),
Mutex: &sync.Mutex{},
}
@@ -80,24 +73,21 @@ func GetBlockedServerList() error {
// IsBlockedAddress checks if the given address is in the blocked servers list.
func IsBlockedAddress(address string) bool {
split := strings.Split(strings.ToLower(address), ".")
isIPAddress := ipAddressRegex.MatchString(address)
addressSegments := strings.Split(strings.ToLower(address), ".")
isIPv4Address := ipAddressRegex.MatchString(address)
for k := range split {
var newAddress string
for i := range addressSegments {
var checkAddress string
if k == 0 {
newAddress = strings.Join(split, ".")
} else if isIPAddress {
newAddress = fmt.Sprintf("%s.*", strings.Join(split[0:len(split)-k], "."))
if i == 0 {
checkAddress = strings.Join(addressSegments, ".")
} else if isIPv4Address {
checkAddress = fmt.Sprintf("%s.*", strings.Join(addressSegments[0:len(addressSegments)-i], "."))
} else {
newAddress = fmt.Sprintf("*.%s", strings.Join(split[k:], "."))
checkAddress = fmt.Sprintf("*.%s", strings.Join(addressSegments[i:], "."))
}
newAddressBytes := sha1.Sum([]byte(newAddress))
newAddressHash := hex.EncodeToString(newAddressBytes[:])
if blockedServers.Has(newAddressHash) {
if blockedServers.Has(SHA256(checkAddress)) {
return true
}
}
@@ -141,6 +131,13 @@ func GetInstanceID() (uint16, error) {
return 0, nil
}
// SHA256 returns the result of hashing the input value using SHA256 algorithm.
func SHA256(input string) string {
result := sha1.Sum([]byte(input))
return hex.EncodeToString(result[:])
}
// PointerOf returns a pointer of the argument passed.
func PointerOf[T any](v T) *T {
return &v