From 3056a936c975449bc9cf32249172215352f351b8 Mon Sep 17 00:00:00 2001 From: Laura Date: Fri, 20 Jun 2025 22:04:34 +0200 Subject: [PATCH] rate limiting --- server/main.go | 16 ++++++++++++++ server/protocol.go | 28 +++++++++++++++++++++++++ server/rates.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 server/rates.go diff --git a/server/main.go b/server/main.go index 181a959..1f8ce30 100644 --- a/server/main.go +++ b/server/main.go @@ -6,10 +6,13 @@ import ( "time" "github.com/coalaura/logger" + "github.com/coalaura/up/internal" "github.com/go-chi/chi/v5" "github.com/patrickmn/go-cache" ) +const MaxParallel = 8 + var ( log = logger.New().DetectTerminal().WithOptions(logger.Options{ NoLevel: true, @@ -17,9 +20,22 @@ var ( challenges = cache.New(10*time.Second, time.Minute) sessions = cache.New(10*time.Second, time.Minute) + rates = NewRateLimiter() ) func main() { + challenges.OnEvicted(func(_ string, entry interface{}) { + challenge := entry.(internal.ChallengeEntry) + + rates.Dec(challenge.Client) + }) + + sessions.OnEvicted(func(_ string, entry interface{}) { + session := entry.(internal.SessionEntry) + + rates.Dec(session.Client) + }) + authorized, err := LoadAuthorizedKeys() log.MustPanic(err) diff --git a/server/protocol.go b/server/protocol.go index e19f45d..8df27fd 100644 --- a/server/protocol.go +++ b/server/protocol.go @@ -48,6 +48,18 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m log.Printf("request: received new request from %s\n", ip) + current, pass, fail := rates.Inc(ip) + defer fail() + + if current > MaxParallel { + w.WriteHeader(http.StatusTooManyRequests) + + log.Warning("request: too many requests") + log.WarningE(err) + + return + } + var request internal.AuthRequest reader := io.LimitReader(r.Body, 4096) @@ -91,6 +103,8 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m PublicKey: public, }, cache.DefaultExpiration) + pass() + log.Printf("request: issued challenge to %s\n", ip) w.Header().Set("Content-Type", "application/msgpack") @@ -110,6 +124,18 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma log.Printf("complete: received completion from %s\n", ip) + current, pass, fail := rates.Inc(ip) + defer fail() + + if current > MaxParallel { + w.WriteHeader(http.StatusTooManyRequests) + + log.Warning("request: too many requests") + log.WarningE(err) + + return + } + var response internal.AuthResponse reader := io.LimitReader(r.Body, 4096) @@ -215,6 +241,8 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma Client: ip, }, cache.DefaultExpiration) + pass() + log.Printf("complete: authentication completed for %s\n", ip) w.Header().Set("Content-Type", "application/msgpack") diff --git a/server/rates.go b/server/rates.go new file mode 100644 index 0000000..87e0381 --- /dev/null +++ b/server/rates.go @@ -0,0 +1,52 @@ +package main + +import ( + "sync" + "sync/atomic" +) + +type RateLimiter struct { + sync.Map +} + +func NewRateLimiter() *RateLimiter { + return &RateLimiter{} +} + +func (rl *RateLimiter) Get(key string) *atomic.Uint32 { + val, ok := rl.Map.Load(key) + if ok { + return val.(*atomic.Uint32) + } + + actual, _ := rl.Map.LoadOrStore(key, &atomic.Uint32{}) + + return actual.(*atomic.Uint32) +} + +func (rl *RateLimiter) Inc(key string) (uint32, func(), func()) { + val := rl.Get(key) + new := val.Add(1) + + var done uint32 + + pass := func() { + atomic.CompareAndSwapUint32(&done, 0, 1) + } + + fail := func() { + if !atomic.CompareAndSwapUint32(&done, 0, 1) { + return + } + + val.Add(^uint32(0)) + } + + return new, pass, fail +} + +func (rl *RateLimiter) Dec(key string) uint32 { + val := rl.Get(key) + + return val.Add(^uint32(0)) +}