1
0
mirror of https://github.com/coalaura/up.git synced 2025-07-17 21:44:35 +00:00

rate limiting

This commit is contained in:
Laura
2025-06-20 22:04:34 +02:00
parent 8ec9a287ae
commit 3056a936c9
3 changed files with 96 additions and 0 deletions

View File

@ -6,10 +6,13 @@ import (
"time" "time"
"github.com/coalaura/logger" "github.com/coalaura/logger"
"github.com/coalaura/up/internal"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
) )
const MaxParallel = 8
var ( var (
log = logger.New().DetectTerminal().WithOptions(logger.Options{ log = logger.New().DetectTerminal().WithOptions(logger.Options{
NoLevel: true, NoLevel: true,
@ -17,9 +20,22 @@ var (
challenges = cache.New(10*time.Second, time.Minute) challenges = cache.New(10*time.Second, time.Minute)
sessions = cache.New(10*time.Second, time.Minute) sessions = cache.New(10*time.Second, time.Minute)
rates = NewRateLimiter()
) )
func main() { func main() {
challenges.OnEvicted(func(_ string, entry interface{}) {
challenge := entry.(internal.ChallengeEntry)
rates.Dec(challenge.Client)
})
sessions.OnEvicted(func(_ string, entry interface{}) {
session := entry.(internal.SessionEntry)
rates.Dec(session.Client)
})
authorized, err := LoadAuthorizedKeys() authorized, err := LoadAuthorizedKeys()
log.MustPanic(err) log.MustPanic(err)

View File

@ -48,6 +48,18 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m
log.Printf("request: received new request from %s\n", ip) log.Printf("request: received new request from %s\n", ip)
current, pass, fail := rates.Inc(ip)
defer fail()
if current > MaxParallel {
w.WriteHeader(http.StatusTooManyRequests)
log.Warning("request: too many requests")
log.WarningE(err)
return
}
var request internal.AuthRequest var request internal.AuthRequest
reader := io.LimitReader(r.Body, 4096) reader := io.LimitReader(r.Body, 4096)
@ -91,6 +103,8 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m
PublicKey: public, PublicKey: public,
}, cache.DefaultExpiration) }, cache.DefaultExpiration)
pass()
log.Printf("request: issued challenge to %s\n", ip) log.Printf("request: issued challenge to %s\n", ip)
w.Header().Set("Content-Type", "application/msgpack") w.Header().Set("Content-Type", "application/msgpack")
@ -110,6 +124,18 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma
log.Printf("complete: received completion from %s\n", ip) log.Printf("complete: received completion from %s\n", ip)
current, pass, fail := rates.Inc(ip)
defer fail()
if current > MaxParallel {
w.WriteHeader(http.StatusTooManyRequests)
log.Warning("request: too many requests")
log.WarningE(err)
return
}
var response internal.AuthResponse var response internal.AuthResponse
reader := io.LimitReader(r.Body, 4096) reader := io.LimitReader(r.Body, 4096)
@ -215,6 +241,8 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma
Client: ip, Client: ip,
}, cache.DefaultExpiration) }, cache.DefaultExpiration)
pass()
log.Printf("complete: authentication completed for %s\n", ip) log.Printf("complete: authentication completed for %s\n", ip)
w.Header().Set("Content-Type", "application/msgpack") w.Header().Set("Content-Type", "application/msgpack")

52
server/rates.go Normal file
View File

@ -0,0 +1,52 @@
package main
import (
"sync"
"sync/atomic"
)
type RateLimiter struct {
sync.Map
}
func NewRateLimiter() *RateLimiter {
return &RateLimiter{}
}
func (rl *RateLimiter) Get(key string) *atomic.Uint32 {
val, ok := rl.Map.Load(key)
if ok {
return val.(*atomic.Uint32)
}
actual, _ := rl.Map.LoadOrStore(key, &atomic.Uint32{})
return actual.(*atomic.Uint32)
}
func (rl *RateLimiter) Inc(key string) (uint32, func(), func()) {
val := rl.Get(key)
new := val.Add(1)
var done uint32
pass := func() {
atomic.CompareAndSwapUint32(&done, 0, 1)
}
fail := func() {
if !atomic.CompareAndSwapUint32(&done, 0, 1) {
return
}
val.Add(^uint32(0))
}
return new, pass, fail
}
func (rl *RateLimiter) Dec(key string) uint32 {
val := rl.Get(key)
return val.Add(^uint32(0))
}