diff --git a/go.mod b/go.mod index 16ca091..1841a81 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24.2 require ( github.com/coalaura/logger v1.4.5 github.com/go-chi/chi/v5 v5.2.2 + github.com/patrickmn/go-cache v2.1.0+incompatible github.com/urfave/cli/v3 v3.3.8 github.com/vmihailenco/msgpack/v5 v5.4.1 golang.org/x/crypto v0.39.0 diff --git a/go.sum b/go.sum index bfcfdec..6bb58f2 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= github.com/go-chi/chi/v5 v5.2.2/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/internal/types.go b/internal/types.go index a059cef..fad0e89 100644 --- a/internal/types.go +++ b/internal/types.go @@ -1,20 +1,16 @@ package internal import ( - "time" - "golang.org/x/crypto/ssh" ) type ChallengeEntry struct { Challenge []byte PublicKey ssh.PublicKey - Expires time.Time } type SessionEntry struct { PublicKey ssh.PublicKey - Expires time.Time } type AuthRequest struct { diff --git a/server/connection.go b/server/connection.go deleted file mode 100644 index 225ee09..0000000 --- a/server/connection.go +++ /dev/null @@ -1,12 +0,0 @@ -package main - -import ( - "net" - "time" -) - -func HandleConnection(conn net.Conn) error { - time.Sleep(10 * time.Second) - - return nil -} diff --git a/server/main.go b/server/main.go index 83fa37f..1298b37 100644 --- a/server/main.go +++ b/server/main.go @@ -2,10 +2,11 @@ package main import ( "net/http" - "sync" + "time" "github.com/coalaura/logger" "github.com/go-chi/chi/v5" + "github.com/patrickmn/go-cache" ) var ( @@ -13,8 +14,8 @@ var ( NoLevel: true, }) - challenges sync.Map - sessions sync.Map + challenges = cache.New(10*time.Second, time.Minute) + sessions = cache.New(10*time.Second, time.Minute) ) func main() { diff --git a/server/protocol.go b/server/protocol.go index 303fbea..1122349 100644 --- a/server/protocol.go +++ b/server/protocol.go @@ -8,9 +8,9 @@ import ( "net/http" "os" "path/filepath" - "time" "github.com/coalaura/up/internal" + "github.com/patrickmn/go-cache" "github.com/vmihailenco/msgpack/v5" "golang.org/x/crypto/ssh" ) @@ -65,11 +65,10 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m return } - challenges.Store(challenge.Token, internal.ChallengeEntry{ + challenges.Set(challenge.Token, internal.ChallengeEntry{ Challenge: raw, PublicKey: public, - Expires: time.Now().Add(20 * time.Second), - }) + }, cache.DefaultExpiration) log.Printf("request: issued challenge to %s\n", r.RemoteAddr) @@ -101,7 +100,7 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma return } - entry, ok := challenges.LoadAndDelete(response.Token) + entry, ok := challenges.Get(response.Token) if !ok { w.WriteHeader(http.StatusBadRequest) @@ -110,16 +109,10 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma return } + challenges.Delete(response.Token) + challenge := entry.(internal.ChallengeEntry) - if time.Now().After(challenge.Expires) { - w.WriteHeader(http.StatusBadRequest) - - log.Warning("complete: challenge expired") - - return - } - publicA := public.Marshal() publicB := challenge.PublicKey.Marshal() @@ -173,10 +166,9 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma return } - sessions.Store(token, internal.SessionEntry{ + sessions.Set(token, internal.SessionEntry{ PublicKey: public, - Expires: time.Now().Add(5 * time.Minute), - }) + }, cache.DefaultExpiration) log.Printf("complete: completed auth for %s\n", r.RemoteAddr) @@ -198,8 +190,7 @@ func HandleReceiveRequest(w http.ResponseWriter, r *http.Request) { return } - entry, ok := sessions.LoadAndDelete(token) - if !ok { + if _, ok := sessions.Get(token); !ok { w.WriteHeader(http.StatusBadRequest) log.Warning("receive: invalid token") @@ -207,15 +198,7 @@ func HandleReceiveRequest(w http.ResponseWriter, r *http.Request) { return } - session := entry.(internal.SessionEntry) - - if time.Now().After(session.Expires) { - w.WriteHeader(http.StatusBadRequest) - - log.Warning("receive: session expired") - - return - } + sessions.Delete(token) reader, err := r.MultipartReader() if err != nil {