diff --git a/server/protocol.go b/server/protocol.go index 8756421..62be1df 100644 --- a/server/protocol.go +++ b/server/protocol.go @@ -18,6 +18,7 @@ import ( ) var ( + UnauthorizedErr = errors.New("unauthorized key") SignatureFormats = map[string]bool{ "ssh-ed25519": true, "ssh-rsa": true, @@ -49,7 +50,11 @@ func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized m public, err := DecodeAndAuthorizePublicKey(request.Public, authorized) if err != nil { - w.WriteHeader(http.StatusBadRequest) + if err == UnauthorizedErr { + w.WriteHeader(http.StatusUnauthorized) + } else { + w.WriteHeader(http.StatusBadRequest) + } log.Warning("request: failed to parse or authorize public key") log.WarningE(err) @@ -94,7 +99,11 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma public, err := DecodeAndAuthorizePublicKey(response.Public, authorized) if err != nil { - w.WriteHeader(http.StatusBadRequest) + if err == UnauthorizedErr { + w.WriteHeader(http.StatusUnauthorized) + } else { + w.WriteHeader(http.StatusBadRequest) + } log.Warning("complete: failed to parse or authorize public key") log.WarningE(err) @@ -104,7 +113,7 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma entry, ok := challenges.Get(response.Token) if !ok { - w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(http.StatusUnauthorized) log.Warning("complete: invalid challenge token") @@ -119,7 +128,7 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma publicB := challenge.PublicKey.Marshal() if !bytes.Equal(publicA, publicB) { - w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(http.StatusUnauthorized) log.Warning("complete: incorrect public key") @@ -150,7 +159,7 @@ func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized ma } if err = public.Verify(challenge.Challenge, sig); err != nil { - w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(http.StatusUnauthorized) log.Warning("complete: failed to verify signature") log.WarningE(err) @@ -185,7 +194,7 @@ func HandleReceiveRequest(w http.ResponseWriter, r *http.Request) { token := r.Header.Get("Authorization") if token == "" { - w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(http.StatusUnauthorized) log.Warning("receive: missing token") @@ -193,7 +202,7 @@ func HandleReceiveRequest(w http.ResponseWriter, r *http.Request) { } if _, ok := sessions.Get(token); !ok { - w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(http.StatusUnauthorized) log.Warning("receive: invalid token") @@ -284,7 +293,7 @@ func DecodeAndAuthorizePublicKey(public string, authorized map[string]ssh.Public } if _, ok := authorized[string(key.Marshal())]; !ok { - return nil, errors.New("unauthorized key") + return nil, UnauthorizedErr } return key, nil