diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fd20df7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +files +*.bin +example.webp \ No newline at end of file diff --git a/client/progress.go b/client/progress.go new file mode 100644 index 0000000..4aa8cba --- /dev/null +++ b/client/progress.go @@ -0,0 +1,35 @@ +package main + +import ( + "io" + + "github.com/coalaura/progress" +) + +type ProgressReader struct { + io.Reader + bar *progress.Bar +} + +func NewProgressReader(label string, total int64, reader io.Reader) *ProgressReader { + bar := progress.NewProgressBarWithTheme(label, total, progress.ThemeDots) + + bar.Start() + + return &ProgressReader{ + Reader: reader, + bar: bar, + } +} + +func (pr *ProgressReader) Read(p []byte) (int, error) { + n, err := pr.Reader.Read(p) + + pr.bar.IncrementBy(int64(n)) + + return n, err +} + +func (pr *ProgressReader) Close() { + pr.bar.Stop() +} diff --git a/client/protocol.go b/client/protocol.go index d7ea34b..394882b 100644 --- a/client/protocol.go +++ b/client/protocol.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/rand" "encoding/base64" - "encoding/json" "errors" "fmt" "io" @@ -14,18 +13,19 @@ import ( "path/filepath" "github.com/coalaura/up/internal" + "github.com/vmihailenco/msgpack/v5" "golang.org/x/crypto/ssh" ) func RequestChallenge(target, public string) (*internal.AuthChallenge, error) { - request, err := json.Marshal(internal.AuthRequest{ + request, err := msgpack.Marshal(internal.AuthRequest{ Public: public, }) if err != nil { return nil, fmt.Errorf("failed to marshal request: %v", err) } - response, err := http.Post(fmt.Sprintf("%s/request", target), "application/json", bytes.NewReader(request)) + response, err := http.Post(fmt.Sprintf("%s/request", target), "application/msgpack", bytes.NewReader(request)) if err != nil { return nil, fmt.Errorf("failed to send request: %v", err) } @@ -38,7 +38,7 @@ func RequestChallenge(target, public string) (*internal.AuthChallenge, error) { var challenge internal.AuthChallenge - if err := json.NewDecoder(response.Body).Decode(&challenge); err != nil { + if err := msgpack.NewDecoder(response.Body).Decode(&challenge); err != nil { return nil, fmt.Errorf("failed to unmarshal response: %v", err) } @@ -56,7 +56,7 @@ func CompleteChallenge(target, public string, private ssh.Signer, challenge *int return nil, fmt.Errorf("failed to sign challenge: %v", err) } - request, err := json.Marshal(internal.AuthResponse{ + request, err := msgpack.Marshal(internal.AuthResponse{ Token: challenge.Token, Public: public, Format: signature.Format, @@ -66,7 +66,7 @@ func CompleteChallenge(target, public string, private ssh.Signer, challenge *int return nil, fmt.Errorf("failed to marshal request: %v", err) } - response, err := http.Post(fmt.Sprintf("%s/complete", target), "application/json", bytes.NewReader(request)) + response, err := http.Post(fmt.Sprintf("%s/complete", target), "application/msgpack", bytes.NewReader(request)) if err != nil { return nil, fmt.Errorf("failed to send request: %v", err) } @@ -79,7 +79,7 @@ func CompleteChallenge(target, public string, private ssh.Signer, challenge *int var result internal.AuthResponse - if err := json.NewDecoder(response.Body).Decode(&result); err != nil { + if err := msgpack.NewDecoder(response.Body).Decode(&result); err != nil { return nil, fmt.Errorf("failed to unmarshal response: %v", err) } @@ -87,22 +87,38 @@ func CompleteChallenge(target, public string, private ssh.Signer, challenge *int } func SendFile(target, token string, file *os.File) error { - var buf bytes.Buffer - - writer := multipart.NewWriter(&buf) - - part, err := writer.CreateFormFile("file", filepath.Base(file.Name())) + stat, err := file.Stat() if err != nil { - return fmt.Errorf("failed to create form file: %v", err) + return fmt.Errorf("failed to stat file: %v", err) } - if _, err := io.Copy(part, file); err != nil { - return fmt.Errorf("failed to copy file: %v", err) - } + pReader, pWriter := io.Pipe() - writer.Close() + writer := multipart.NewWriter(pWriter) - request, err := http.NewRequest("POST", fmt.Sprintf("%s/receive", target), &buf) + go func() { + defer pWriter.Close() + + part, err := writer.CreateFormFile("file", filepath.Base(file.Name())) + if err != nil { + pWriter.CloseWithError(err) + + return + } + + if _, err := io.Copy(part, file); err != nil { + pWriter.CloseWithError(err) + + return + } + + writer.Close() + }() + + reader := NewProgressReader("Uploading", stat.Size(), pReader) + defer reader.Close() + + request, err := http.NewRequest("POST", fmt.Sprintf("%s/receive", target), reader) if err != nil { return fmt.Errorf("failed to create request: %v", err) } @@ -116,6 +132,7 @@ func SendFile(target, token string, file *os.File) error { } response.Body.Close() + reader.Close() if response.StatusCode != http.StatusOK { return errors.New(response.Status) diff --git a/go.mod b/go.mod index 130b156..8895812 100644 --- a/go.mod +++ b/go.mod @@ -4,18 +4,16 @@ go 1.24.2 require ( github.com/coalaura/logger v1.4.5 - github.com/fasthttp/router v1.5.4 + github.com/coalaura/progress v1.1.6 + github.com/go-chi/chi/v5 v5.2.2 github.com/urfave/cli/v3 v3.3.8 - github.com/valyala/fasthttp v1.62.0 + github.com/vmihailenco/msgpack/v5 v5.4.1 golang.org/x/crypto v0.39.0 ) require ( - github.com/andybalholm/brotli v1.1.1 // indirect github.com/gookit/color v1.5.4 // indirect - github.com/klauspost/compress v1.18.0 // indirect - github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/term v0.32.0 // indirect diff --git a/go.sum b/go.sum index 2ebed83..7fa8934 100644 --- a/go.sum +++ b/go.sum @@ -1,31 +1,25 @@ -github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= -github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/coalaura/logger v1.4.5 h1:xXazOab4qXaltUbD4TrQdSs2TtLB+k6t0t6y/M8LR3Q= github.com/coalaura/logger v1.4.5/go.mod h1:3HCYCWmsWmYW175e2/fZL9BWjJutr2W+7adeh1BPHkg= +github.com/coalaura/progress v1.1.6 h1:SOeuvH3M/sUDezyjCZwBaoMWKyPVcBzhxrL1qZqtV2w= +github.com/coalaura/progress v1.1.6/go.mod h1:2t8PFWZG8m+c6x8fBrGyJajclxVDLytNctoXVYyFPbc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8= -github.com/fasthttp/router v1.5.4/go.mod h1:3/hysWq6cky7dTfzaaEPZGdptwjwx0qzTgFCKEWRjgc= +github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= +github.com/go-chi/chi/v5 v5.2.2/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= -github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= -github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287 h1:qIQ0tWF9vxGtkJa24bR+2i53WBCz1nW/Pc47oVYauC4= -github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/urfave/cli/v3 v3.3.8 h1:BzolUExliMdet9NlJ/u4m5vHSotJ3PzEqSAZ1oPMa/E= github.com/urfave/cli/v3 v3.3.8/go.mod h1:FJSKtM/9AiiTOJL4fJ6TbMUkxBXn7GO9guZqoZtpYpo= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0= -github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= -github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= -github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= diff --git a/server/main.go b/server/main.go index d02b4c8..2f89f76 100644 --- a/server/main.go +++ b/server/main.go @@ -1,11 +1,11 @@ package main import ( + "net/http" "sync" "github.com/coalaura/logger" - "github.com/fasthttp/router" - "github.com/valyala/fasthttp" + "github.com/go-chi/chi/v5" ) var ( @@ -21,18 +21,18 @@ func main() { authorized, err := LoadAuthorizedKeys() log.MustPanic(err) - r := router.New() + r := chi.NewRouter() - r.POST("/request", func(ctx *fasthttp.RequestCtx) { - HandleChallengeRequest(ctx, authorized) + r.Post("/request", func(w http.ResponseWriter, r *http.Request) { + HandleChallengeRequest(w, r, authorized) }) - r.POST("/complete", func(ctx *fasthttp.RequestCtx) { - HandleCompleteRequest(ctx, authorized) + r.Post("/complete", func(w http.ResponseWriter, r *http.Request) { + HandleCompleteRequest(w, r, authorized) }) - r.POST("/receive", HandleReceiveRequest) + r.Post("/receive", HandleReceiveRequest) log.Println("Listening on :7966") - fasthttp.ListenAndServe(":7966", r.Handler) + http.ListenAndServe(":7966", r) } diff --git a/server/protocol.go b/server/protocol.go index eca04dc..1947a41 100644 --- a/server/protocol.go +++ b/server/protocol.go @@ -3,15 +3,15 @@ package main import ( "bytes" "encoding/base64" - "encoding/json" "errors" "io" + "net/http" "os" "path/filepath" "time" "github.com/coalaura/up/internal" - "github.com/valyala/fasthttp" + "github.com/vmihailenco/msgpack/v5" "golang.org/x/crypto/ssh" ) @@ -31,11 +31,11 @@ func IsSignatureFormatValid(format string) bool { return SignatureFormats[format] } -func HandleChallengeRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.PublicKey) { +func HandleChallengeRequest(w http.ResponseWriter, r *http.Request, authorized map[string]ssh.PublicKey) { var request internal.AuthRequest - if err := json.Unmarshal(ctx.PostBody(), &request); err != nil { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + if err := msgpack.NewDecoder(r.Body).Decode(&request); err != nil { + w.WriteHeader(http.StatusBadRequest) log.Warning("request: failed to decode request") log.WarningE(err) @@ -45,7 +45,7 @@ func HandleChallengeRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh. public, err := DecodeAndAuthorizePublicKey(request.Public, authorized) if err != nil { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("request: failed to parse/authorize public key") log.WarningE(err) @@ -55,7 +55,7 @@ func HandleChallengeRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh. challenge, raw, err := internal.FreshChallenge() if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) log.Warning("request: failed to generate challenge") log.WarningE(err) @@ -71,15 +71,15 @@ func HandleChallengeRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh. log.Println("new auth request") - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(challenge) + w.Header().Set("Content-Type", "application/msgpack") + msgpack.NewEncoder(w).Encode(challenge) } -func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.PublicKey) { +func HandleCompleteRequest(w http.ResponseWriter, r *http.Request, authorized map[string]ssh.PublicKey) { var response internal.AuthResponse - if err := json.Unmarshal(ctx.PostBody(), &response); err != nil { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + if err := msgpack.NewDecoder(r.Body).Decode(&response); err != nil { + w.WriteHeader(http.StatusBadRequest) log.Warning("complete: failed to decode response") log.WarningE(err) @@ -89,7 +89,7 @@ func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.P public, err := DecodeAndAuthorizePublicKey(response.Public, authorized) if err != nil { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("complete: failed to parse/authorize public key") log.WarningE(err) @@ -99,7 +99,7 @@ func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.P entry, ok := challenges.LoadAndDelete(response.Token) if !ok { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("complete: invalid challenge token") @@ -109,7 +109,7 @@ func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.P challenge := entry.(internal.ChallengeEntry) if time.Now().After(challenge.Expires) { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("complete: challenge expired") @@ -120,7 +120,7 @@ func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.P publicB := challenge.PublicKey.Marshal() if !bytes.Equal(publicA, publicB) { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("complete: incorrect public key") @@ -128,7 +128,7 @@ func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.P } if !IsSignatureFormatValid(response.Format) { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("complete: unsupported signature format") @@ -137,7 +137,7 @@ func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.P signature, err := base64.StdEncoding.DecodeString(response.Signature) if err != nil { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("complete: failed to decode signature") log.WarningE(err) @@ -151,7 +151,7 @@ func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.P } if err = public.Verify(challenge.Challenge, sig); err != nil { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("complete: failed to verify signature") log.WarningE(err) @@ -161,7 +161,7 @@ func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.P token, err := RandomToken(64) if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) log.Warning("complete: failed to create token") log.WarningE(err) @@ -176,16 +176,16 @@ func HandleCompleteRequest(ctx *fasthttp.RequestCtx, authorized map[string]ssh.P log.Println("auth completed") - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(internal.AuthResult{ + w.Header().Set("Content-Type", "application/msgpack") + msgpack.NewEncoder(w).Encode(internal.AuthResult{ Token: token, }) } -func HandleReceiveRequest(ctx *fasthttp.RequestCtx) { - token := string(ctx.Request.Header.Peek("Authorization")) +func HandleReceiveRequest(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") if token == "" { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("receive: missing token") @@ -194,7 +194,7 @@ func HandleReceiveRequest(ctx *fasthttp.RequestCtx) { entry, ok := sessions.LoadAndDelete(token) if !ok { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("receive: invalid token") @@ -204,45 +204,43 @@ func HandleReceiveRequest(ctx *fasthttp.RequestCtx) { session := entry.(internal.SessionEntry) if time.Now().After(session.Expires) { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) log.Warning("receive: session expired") return } - form, err := ctx.MultipartForm() + reader, err := r.MultipartReader() if err != nil { - ctx.SetStatusCode(fasthttp.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) - log.Warning("receive: failed to parse multipart form") + log.Warning("receive: failed to open multipart form") log.WarningE(err) return } - files := form.File["file"] - if len(files) == 0 { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - - log.Warning("receive: no files received") - - return - } - - header := files[0] - name := filepath.Base(header.Filename) - - source, err := header.Open() + part, err := reader.NextPart() if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) + w.WriteHeader(http.StatusBadRequest) - log.Warning("receive: failed to open sent file") + log.Warning("receive: failed to read multipart form") + log.WarningE(err) return } - defer source.Close() + if part.FormName() != "file" { + w.WriteHeader(http.StatusBadRequest) + + log.Warning("receive: invalid multipart part") + log.WarningE(err) + + return + } + + name := filepath.Base(part.FileName()) if _, err := os.Stat("files"); os.IsNotExist(err) { os.Mkdir("files", 0755) @@ -250,7 +248,7 @@ func HandleReceiveRequest(ctx *fasthttp.RequestCtx) { target, err := os.OpenFile(filepath.Join("files", name), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) log.Warning("receive: failed to open target file") @@ -259,15 +257,15 @@ func HandleReceiveRequest(ctx *fasthttp.RequestCtx) { defer target.Close() - if _, err := io.Copy(target, source); err != nil { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) + if _, err := io.Copy(target, part); err != nil { + w.WriteHeader(http.StatusInternalServerError) log.Warning("receive: failed to copy sent file") return } - ctx.SetStatusCode(fasthttp.StatusOK) + w.WriteHeader(http.StatusOK) } func DecodeAndAuthorizePublicKey(public string, authorized map[string]ssh.PublicKey) (ssh.PublicKey, error) { diff --git a/test.cmd b/test.cmd index c2337db..aa5c325 100644 --- a/test.cmd +++ b/test.cmd @@ -1,3 +1,3 @@ @echo off -go run .\client --key example.key -f example.webp -t localhost:7966 \ No newline at end of file +go run .\client --key example.key -f test.bin -t localhost:7966 \ No newline at end of file