mirror of
https://github.com/coalaura/up.git
synced 2025-07-17 21:44:35 +00:00
global limit & log time
This commit is contained in:
@ -11,7 +11,13 @@ import (
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const MaxParallel = 8
|
||||
const (
|
||||
// Max amount of parallel sessions/challenges per client
|
||||
MaxClientParallel = 8
|
||||
|
||||
// Max amount of parallel sessions/challenges overall
|
||||
MaxGlobalParallel = MaxClientParallel * 8
|
||||
)
|
||||
|
||||
var (
|
||||
log = logger.New().DetectTerminal().WithOptions(logger.Options{
|
||||
|
@ -5,11 +5,13 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/coalaura/up/internal"
|
||||
@ -51,7 +53,7 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m
|
||||
current, pass, fail := rates.Inc(ip)
|
||||
defer fail()
|
||||
|
||||
if current > MaxParallel {
|
||||
if current > MaxClientParallel || current == 0 {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
|
||||
log.Warning("request: too many requests")
|
||||
@ -127,7 +129,7 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma
|
||||
current, pass, fail := rates.Inc(ip)
|
||||
defer fail()
|
||||
|
||||
if current > MaxParallel {
|
||||
if current > MaxClientParallel || current == 0 {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
|
||||
log.Warning("request: too many requests")
|
||||
@ -294,6 +296,8 @@ func HandleReceiveRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
|
||||
reader, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@ -359,7 +363,7 @@ func HandleReceiveRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("receive: stored file %s from %s (%d bytes)\n", name, ip, read)
|
||||
log.Printf("receive: stored file %s from %s (%d bytes, took %s)\n", name, ip, read, RoundSince(time.Since(t0)))
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
@ -406,3 +410,10 @@ func SanitizeFilename(name string) string {
|
||||
|
||||
return cleaned.String()
|
||||
}
|
||||
|
||||
func RoundSince(since time.Duration) time.Duration {
|
||||
exp := int(math.Log10(float64(since)))
|
||||
by := time.Duration(math.Pow(10, max(0, float64(exp)-2)))
|
||||
|
||||
return since.Round(by)
|
||||
}
|
||||
|
@ -7,8 +7,11 @@ import (
|
||||
|
||||
type RateLimiter struct {
|
||||
sync.Map
|
||||
total atomic.Uint32
|
||||
}
|
||||
|
||||
const MinusOne uint32 = ^uint32(0)
|
||||
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
return &RateLimiter{}
|
||||
}
|
||||
@ -25,6 +28,12 @@ func (rl *RateLimiter) Get(key string) *atomic.Uint32 {
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Inc(key string) (uint32, func(), func()) {
|
||||
if rl.total.Add(1) > MaxGlobalParallel {
|
||||
rl.total.Add(MinusOne)
|
||||
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
val := rl.Get(key)
|
||||
new := val.Add(1)
|
||||
|
||||
@ -39,14 +48,17 @@ func (rl *RateLimiter) Inc(key string) (uint32, func(), func()) {
|
||||
return
|
||||
}
|
||||
|
||||
val.Add(^uint32(0))
|
||||
rl.total.Add(MinusOne)
|
||||
val.Add(MinusOne)
|
||||
}
|
||||
|
||||
return new, pass, fail
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Dec(key string) uint32 {
|
||||
rl.total.Add(MinusOne)
|
||||
|
||||
val := rl.Get(key)
|
||||
|
||||
return val.Add(^uint32(0))
|
||||
return val.Add(MinusOne)
|
||||
}
|
||||
|
Reference in New Issue
Block a user