Files
ping-server/src/util.go
2023-07-28 15:27:42 -05:00

161 lines
3.6 KiB
Go

package main
import (
"crypto/sha1"
_ "embed"
"encoding/hex"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"regexp"
"strconv"
"strings"
"sync"
)
// Package-level state shared by the helpers in this file.
var (
	// blockedServers holds the entries that IsBlockedAddress compares hashed
	// addresses against. It stays nil until GetBlockedServerList succeeds.
	blockedServers *MutexArray[string]

	// hostRegEx accepts a dotted name of at least two labels (letters, digits,
	// hyphens), optionally followed by ":" and one to five digits.
	hostRegEx = regexp.MustCompile(`^[A-Za-z0-9-]+(\.[A-Za-z0-9-]+)+(:\d{1,5})?$`)

	// ipAddressRegEx matches the dotted-quad shape of an IPv4 address. It does
	// not range-check the octets (e.g. "999.1.1.1" matches).
	ipAddressRegEx = regexp.MustCompile(`^\d{1,3}(\.\d{1,3}){3}$`)
)
// MutexArray is a slice of comparable values guarded by a mutex so it can be
// shared between goroutines.
//
// Mutex must be set to a non-nil *sync.Mutex before any method is used; a
// MutexArray whose Mutex is nil is not usable.
type MutexArray[T comparable] struct {
	// List holds the stored values. Has reads it while holding Mutex, so
	// anything that replaces or mutates List should hold Mutex as well.
	List []T
	// Mutex guards List.
	Mutex *sync.Mutex
}

// Has reports whether value is present in the array.
//
// A nil *MutexArray is treated as empty and reports false. Without this check,
// lookups made before the array is created would panic on a nil pointer.
// The scan is linear in len(List).
func (m *MutexArray[T]) Has(value T) bool {
	if m == nil {
		return false
	}
	m.Mutex.Lock()
	defer m.Mutex.Unlock()
	for _, v := range m.List {
		if v == value {
			return true
		}
	}
	return false
}
// GetBlockedServerList downloads Mojang's blocked servers list from the session
// server and stores it in blockedServers.
//
// The response body is split on whitespace into one entry per line; blank lines
// and stray carriage returns are dropped. On the first successful load a new
// MutexArray is created. On later loads the existing array's List is swapped
// under its mutex instead of replacing the pointer, so concurrent Has calls do
// not race with the refresh.
//
// On any error (transport failure, non-200 status, body read failure) the
// previously loaded list, if any, is left unchanged and the error is returned.
//
// NOTE(review): the first assignment of blockedServers is itself unsynchronized;
// confirm callers perform the initial load before lookups run concurrently.
func GetBlockedServerList() error {
	resp, err := http.Get("https://sessionserver.mojang.com/blockedservers")
	if err != nil {
		return err
	}
	// Registered before the status check so the body is closed on every return path.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("mojang: unexpected status code: %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	entries := strings.Fields(string(body))

	// Refresh in place so readers holding the current *MutexArray see the new
	// contents under the same lock they already use.
	if blockedServers != nil && blockedServers.Mutex != nil {
		blockedServers.Mutex.Lock()
		blockedServers.List = entries
		blockedServers.Mutex.Unlock()
		return nil
	}

	blockedServers = &MutexArray[string]{
		List:  entries,
		Mutex: &sync.Mutex{},
	}
	return nil
}
// IsBlockedAddress checks if the given address is in the blocked servers list.
func IsBlockedAddress(address string) bool {
addressSegments := strings.Split(strings.ToLower(address), ".")
isIPv4Address := ipAddressRegEx.MatchString(address)
for i := range addressSegments {
var checkAddress string
if i == 0 {
checkAddress = strings.Join(addressSegments, ".")
} else if isIPv4Address {
checkAddress = fmt.Sprintf("%s.*", strings.Join(addressSegments[0:len(addressSegments)-i], "."))
} else {
checkAddress = fmt.Sprintf("*.%s", strings.Join(addressSegments[i:], "."))
}
if blockedServers.Has(SHA256(checkAddress)) {
return true
}
}
return false
}
// ParseAddress splits address into a hostname and a port.
//
// address must match hostRegEx: a dotted name of at least two labels,
// optionally followed by ":" and one to five digits. When no port is given,
// defaultPort is returned. A port that does not fit in a uint16 (e.g. "99999")
// is reported as an error that wraps the underlying strconv error.
func ParseAddress(address string, defaultPort uint16) (string, uint16, error) {
	if !hostRegEx.MatchString(address) {
		return "", 0, fmt.Errorf("'%s' does not match any known address", address)
	}

	// The host part of hostRegEx cannot contain ':', so the first ':' (if any)
	// always separates host from port.
	host, rawPort, hasPort := strings.Cut(address, ":")
	if !hasPort {
		return host, defaultPort, nil
	}

	// The regex allows up to five digits, so the 16-bit parse is what enforces
	// the 0-65535 range.
	port, err := strconv.ParseUint(rawPort, 10, 16)
	if err != nil {
		return "", 0, fmt.Errorf("'%s' has an invalid port: %w", address, err)
	}
	return host, uint16(port), nil
}
// GetInstanceID reads the INSTANCE_ID environment variable as a base-10
// unsigned 16-bit integer.
//
// An unset or empty variable yields 0. A value that fails to parse terminates
// the process via log.Fatal, so the returned error is currently always nil.
// NOTE(review): returning the parse error would let callers decide how to react.
func GetInstanceID() (uint16, error) {
	raw := os.Getenv("INSTANCE_ID")
	if raw == "" {
		return 0, nil
	}

	id, err := strconv.ParseUint(raw, 10, 16)
	if err != nil {
		log.Fatal(err)
	}
	return uint16(id), nil
}
// GetCacheKey derives the key under which a status result is cached in Redis.
//
// The host, port and query flag are URL-encoded together. url.Values.Encode
// sorts by key, so equal inputs always produce the same string. That encoded
// string is then hashed with SHA256 to produce the key.
func GetCacheKey(host string, port uint16, query bool) string {
	params := url.Values{
		"host":  []string{host},
		"port":  []string{strconv.FormatUint(uint64(port), 10)},
		"query": []string{strconv.FormatBool(query)},
	}
	return SHA256(params.Encode())
}
// SHA256 returns the lowercase hex encoding of the SHA-1 digest of input.
//
// Despite its name, this function uses SHA-1, not SHA-256. The name is kept for
// compatibility with existing callers. IsBlockedAddress compares these digests
// against Mojang's blocked servers list, presumably because that list stores
// SHA-1 hashes. Changing the algorithm would break that lookup and would also
// change every key produced by GetCacheKey.
func SHA256(input string) string {
	digest := sha1.Sum([]byte(input))
	return hex.EncodeToString(digest[:])
}
// PointerOf returns a pointer to a newly allocated copy of v. It is handy for
// taking the address of a literal or function result. Each call yields a
// distinct pointer, so writes through it never affect the caller's v.
func PointerOf[T any](v T) *T {
	ptr := new(T)
	*ptr = v
	return ptr
}