mirror of
https://github.com/coalaura/up.git
synced 2025-07-17 21:44:35 +00:00
rate limiting
This commit is contained in:
@ -6,10 +6,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coalaura/logger"
|
"github.com/coalaura/logger"
|
||||||
|
"github.com/coalaura/up/internal"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const MaxParallel = 8
|
||||||
|
|
||||||
var (
|
var (
|
||||||
log = logger.New().DetectTerminal().WithOptions(logger.Options{
|
log = logger.New().DetectTerminal().WithOptions(logger.Options{
|
||||||
NoLevel: true,
|
NoLevel: true,
|
||||||
@ -17,9 +20,22 @@ var (
|
|||||||
|
|
||||||
challenges = cache.New(10*time.Second, time.Minute)
|
challenges = cache.New(10*time.Second, time.Minute)
|
||||||
sessions = cache.New(10*time.Second, time.Minute)
|
sessions = cache.New(10*time.Second, time.Minute)
|
||||||
|
rates = NewRateLimiter()
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
challenges.OnEvicted(func(_ string, entry interface{}) {
|
||||||
|
challenge := entry.(internal.ChallengeEntry)
|
||||||
|
|
||||||
|
rates.Dec(challenge.Client)
|
||||||
|
})
|
||||||
|
|
||||||
|
sessions.OnEvicted(func(_ string, entry interface{}) {
|
||||||
|
session := entry.(internal.SessionEntry)
|
||||||
|
|
||||||
|
rates.Dec(session.Client)
|
||||||
|
})
|
||||||
|
|
||||||
authorized, err := LoadAuthorizedKeys()
|
authorized, err := LoadAuthorizedKeys()
|
||||||
log.MustPanic(err)
|
log.MustPanic(err)
|
||||||
|
|
||||||
|
@ -48,6 +48,18 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m
|
|||||||
|
|
||||||
log.Printf("request: received new request from %s\n", ip)
|
log.Printf("request: received new request from %s\n", ip)
|
||||||
|
|
||||||
|
current, pass, fail := rates.Inc(ip)
|
||||||
|
defer fail()
|
||||||
|
|
||||||
|
if current > MaxParallel {
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
|
||||||
|
log.Warning("request: too many requests")
|
||||||
|
log.WarningE(err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var request internal.AuthRequest
|
var request internal.AuthRequest
|
||||||
|
|
||||||
reader := io.LimitReader(r.Body, 4096)
|
reader := io.LimitReader(r.Body, 4096)
|
||||||
@ -91,6 +103,8 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m
|
|||||||
PublicKey: public,
|
PublicKey: public,
|
||||||
}, cache.DefaultExpiration)
|
}, cache.DefaultExpiration)
|
||||||
|
|
||||||
|
pass()
|
||||||
|
|
||||||
log.Printf("request: issued challenge to %s\n", ip)
|
log.Printf("request: issued challenge to %s\n", ip)
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/msgpack")
|
w.Header().Set("Content-Type", "application/msgpack")
|
||||||
@ -110,6 +124,18 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma
|
|||||||
|
|
||||||
log.Printf("complete: received completion from %s\n", ip)
|
log.Printf("complete: received completion from %s\n", ip)
|
||||||
|
|
||||||
|
current, pass, fail := rates.Inc(ip)
|
||||||
|
defer fail()
|
||||||
|
|
||||||
|
if current > MaxParallel {
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
|
||||||
|
log.Warning("request: too many requests")
|
||||||
|
log.WarningE(err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var response internal.AuthResponse
|
var response internal.AuthResponse
|
||||||
|
|
||||||
reader := io.LimitReader(r.Body, 4096)
|
reader := io.LimitReader(r.Body, 4096)
|
||||||
@ -215,6 +241,8 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma
|
|||||||
Client: ip,
|
Client: ip,
|
||||||
}, cache.DefaultExpiration)
|
}, cache.DefaultExpiration)
|
||||||
|
|
||||||
|
pass()
|
||||||
|
|
||||||
log.Printf("complete: authentication completed for %s\n", ip)
|
log.Printf("complete: authentication completed for %s\n", ip)
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/msgpack")
|
w.Header().Set("Content-Type", "application/msgpack")
|
||||||
|
52
server/rates.go
Normal file
52
server/rates.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RateLimiter struct {
|
||||||
|
sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRateLimiter() *RateLimiter {
|
||||||
|
return &RateLimiter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *RateLimiter) Get(key string) *atomic.Uint32 {
|
||||||
|
val, ok := rl.Map.Load(key)
|
||||||
|
if ok {
|
||||||
|
return val.(*atomic.Uint32)
|
||||||
|
}
|
||||||
|
|
||||||
|
actual, _ := rl.Map.LoadOrStore(key, &atomic.Uint32{})
|
||||||
|
|
||||||
|
return actual.(*atomic.Uint32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *RateLimiter) Inc(key string) (uint32, func(), func()) {
|
||||||
|
val := rl.Get(key)
|
||||||
|
new := val.Add(1)
|
||||||
|
|
||||||
|
var done uint32
|
||||||
|
|
||||||
|
pass := func() {
|
||||||
|
atomic.CompareAndSwapUint32(&done, 0, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fail := func() {
|
||||||
|
if !atomic.CompareAndSwapUint32(&done, 0, 1) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
val.Add(^uint32(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
return new, pass, fail
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *RateLimiter) Dec(key string) uint32 {
|
||||||
|
val := rl.Get(key)
|
||||||
|
|
||||||
|
return val.Add(^uint32(0))
|
||||||
|
}
|
Reference in New Issue
Block a user