From 9c1cfd5e7954546d8627d04d86fe8cca8d0d112e Mon Sep 17 00:00:00 2001 From: Laura Date: Fri, 20 Jun 2025 22:44:16 +0200 Subject: [PATCH] global limit & log time --- server/main.go | 8 +++++++- server/protocol.go | 17 ++++++++++++++--- server/rates.go | 16 ++++++++++++++-- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/server/main.go b/server/main.go index 1f8ce30..f138be0 100644 --- a/server/main.go +++ b/server/main.go @@ -11,7 +11,13 @@ import ( "github.com/patrickmn/go-cache" ) -const MaxParallel = 8 +const ( + // Max amount of parallel sessions/challenges per client + MaxClientParallel = 8 + + // Max amount of parallel sessions/challenges overall + MaxGlobalParallel = MaxClientParallel * 8 +) var ( log = logger.New().DetectTerminal().WithOptions(logger.Options{ diff --git a/server/protocol.go b/server/protocol.go index 8df27fd..fab30bd 100644 --- a/server/protocol.go +++ b/server/protocol.go @@ -5,11 +5,13 @@ import ( "encoding/base64" "errors" "io" + "math" "net" "net/http" "os" "path/filepath" "strings" + "time" "unicode" "github.com/coalaura/up/internal" @@ -51,7 +53,7 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m current, pass, fail := rates.Inc(ip) defer fail() - if current > MaxParallel { + if current > MaxClientParallel || current == 0 { w.WriteHeader(http.StatusTooManyRequests) log.Warning("request: too many requests") @@ -127,7 +129,7 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma current, pass, fail := rates.Inc(ip) defer fail() - if current > MaxParallel { + if current > MaxClientParallel || current == 0 { w.WriteHeader(http.StatusTooManyRequests) log.Warning("request: too many requests") @@ -294,6 +296,8 @@ func HandleReceiveRequest(w http.ResponseWriter, r *http.Request) { return } + t0 := time.Now() + reader, err := r.MultipartReader() if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -359,7 +363,7 @@ func HandleReceiveRequest(w http.ResponseWriter, r *http.Request) { return } - log.Printf("receive: stored file %s from %s (%d bytes)\n", name, ip, read) + log.Printf("receive: stored file %s from %s (%d bytes, took %s)\n", name, ip, read, RoundSince(time.Since(t0))) w.WriteHeader(http.StatusOK) } @@ -406,3 +410,10 @@ func SanitizeFilename(name string) string { return cleaned.String() } + +func RoundSince(since time.Duration) time.Duration { + exp := int(math.Log10(float64(since))) + by := time.Duration(math.Pow(10, max(0, float64(exp)-2))) + + return since.Round(by) +} diff --git a/server/rates.go b/server/rates.go index 87e0381..5677c47 100644 --- a/server/rates.go +++ b/server/rates.go @@ -7,8 +7,11 @@ import ( type RateLimiter struct { sync.Map + total atomic.Uint32 } +const MinusOne uint32 = ^uint32(0) + func NewRateLimiter() *RateLimiter { return &RateLimiter{} } @@ -25,6 +28,12 @@ func (rl *RateLimiter) Get(key string) *atomic.Uint32 { } func (rl *RateLimiter) Inc(key string) (uint32, func(), func()) { + if rl.total.Add(1) > MaxGlobalParallel { + rl.total.Add(MinusOne) + + return 0, nil, nil + } + val := rl.Get(key) new := val.Add(1) @@ -39,14 +48,17 @@ func (rl *RateLimiter) Inc(key string) (uint32, func(), func()) { return } - val.Add(^uint32(0)) + rl.total.Add(MinusOne) + val.Add(MinusOne) } return new, pass, fail } func (rl *RateLimiter) Dec(key string) uint32 { + rl.total.Add(MinusOne) + val := rl.Get(key) - return val.Add(^uint32(0)) + return val.Add(MinusOne) }