diff --git a/.gitignore b/.gitignore index fd20df7..757e400 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ files +example.webp *.bin -example.webp \ No newline at end of file +*.pem \ No newline at end of file diff --git a/client/certificates.go b/client/certificates.go new file mode 100644 index 0000000..85d8f2d --- /dev/null +++ b/client/certificates.go @@ -0,0 +1,178 @@ +package main + +import ( + "bytes" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "errors" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + "sync" +) + +type PinnedCertificate struct { + Name string + Fingerprint string +} + +type CertificateStore struct { + path string + pinned []PinnedCertificate + mx sync.RWMutex +} + +func (cs *CertificateStore) IsPinned(name, fingerprint string) bool { + cs.mx.RLock() + defer cs.mx.RUnlock() + + if len(cs.pinned) == 0 { + return false + } + + for _, pin := range cs.pinned { + if pin.Fingerprint == fingerprint && pin.Name == name { + return true + } + } + + return false +} + +func (cs *CertificateStore) Pin(name, fingerprint string) error { + cs.mx.Lock() + defer cs.mx.Unlock() + + pin := PinnedCertificate{ + Name: name, + Fingerprint: fingerprint, + } + + file, err := os.OpenFile(cs.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + + defer file.Close() + + if _, err = file.WriteString(fmt.Sprintf("%s %s\n", name, fingerprint)); err != nil { + return err + } + + cs.pinned = append(cs.pinned, pin) + + return nil +} + +func LoadCertificateStore() (*CertificateStore, error) { + path, err := PinnedCertificatesPath() + if err != nil { + return nil, err + } + + store := &CertificateStore{ + path: path, + } + + if _, err := os.Stat(path); os.IsNotExist(err) { + return store, nil + } + + contents, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var index int + + for line := range bytes.SplitSeq(contents, []byte("\n")) { + index++ + + if len(line) == 0 { + continue + } + + index := bytes.Index(line, []byte(" ")) + if index == -1 { + return nil, fmt.Errorf("Invalid pinned certificate on line %d\n", index) + } + + name := line[:index] + fingerprint := line[index:] + + if len(fingerprint) < 64 { + return nil, fmt.Errorf("Invalid fingerprint on line %d\n", index) + } + + store.pinned = append(store.pinned, PinnedCertificate{ + Name: string(name), + Fingerprint: string(fingerprint), + }) + } + + return store, nil +} + +func PinnedCertificatesPath() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + return filepath.Join(home, ".up"), nil +} + +func CertificateFingerprint(certificate *x509.Certificate) string { + sum := sha256.Sum256(certificate.Raw) + algo := strings.ToLower(certificate.PublicKeyAlgorithm.String()) + + return fmt.Sprintf("%s-%s", algo, hex.EncodeToString(sum[:])) +} + +func NewPinnedClient(store *CertificateStore) *http.Client { + config := &tls.Config{ + InsecureSkipVerify: true, + } + + config.VerifyConnection = func(cs tls.ConnectionState) error { + if len(cs.PeerCertificates) == 0 { + return errors.New("missing certificate") + } + + certificate := cs.PeerCertificates[0] + + if certificate.Subject.CommonName != "up" { + return errors.New("invalid certificate subject") + } + + name := cs.ServerName + fingerprint := CertificateFingerprint(certificate) + + if store.IsPinned(name, fingerprint) { + return nil + } + + log.Printf("Server fingerprint (%s): %s\n", name, fingerprint) + log.Print("Accept? [y/N]: ") + + var confirm string + + fmt.Scanln(&confirm) + + if strings.ToLower(strings.TrimSpace(confirm)) != "y" { + return errors.New("certificate rejected") + } + + return store.Pin(name, fingerprint) + } + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: config, + }, + } +} diff --git a/client/main.go b/client/main.go index b921014..380b4e3 100644 --- a/client/main.go +++ b/client/main.go @@ -5,14 +5,21 @@ import ( "encoding/base64" "errors" "fmt" - "log" "os" "path/filepath" + "github.com/coalaura/logger" "github.com/urfave/cli/v3" ) -var Version = "dev" +var ( + Version = "dev" + + log = logger.New().DetectTerminal().WithOptions(logger.Options{ + NoTime: true, + NoLevel: true, + }) +) func main() { app := &cli.Command{ @@ -42,14 +49,20 @@ func main() { Suggest: true, } - if err := app.Run(context.Background(), os.Args); err != nil { - fmt.Printf("fatal: %v\n", err) - - os.Exit(1) - } + err := app.Run(context.Background(), os.Args) + log.MustPanic(err) } func run(_ context.Context, cmd *cli.Command) error { + log.Println("loading certificate store") + + store, err := LoadCertificateStore() + if err != nil { + return fmt.Errorf("failed to load certificate store: %v", err) + } + + client := NewPinnedClient(store) + path := cmd.String("key") if path == "" { return errors.New("missing private key") @@ -60,7 +73,7 @@ func run(_ context.Context, cmd *cli.Command) error { return fmt.Errorf("failed to get key path: %v", err) } - log.Printf("using key %s", kPath) + log.Printf("using key %s\n", kPath) path = cmd.String("file") if path == "" { @@ -72,7 +85,7 @@ func run(_ context.Context, cmd *cli.Command) error { return fmt.Errorf("failed to get file path: %v", err) } - log.Printf("using file %s", fPath) + log.Printf("using file %s\n", fPath) file, err := os.OpenFile(fPath, os.O_RDONLY, 0) if err != nil { @@ -88,7 +101,7 @@ func run(_ context.Context, cmd *cli.Command) error { target = fmt.Sprintf("https://%s", target) - log.Printf("using target %s", target) + log.Printf("using target %s\n", target) log.Printf("loading key") @@ -101,19 +114,17 @@ func run(_ context.Context, cmd *cli.Command) error { log.Println("requesting challenge") - challenge, err := RequestChallenge(target, public) + challenge, err := RequestChallenge(client, target, public) if err != nil { return err } log.Println("completing challenge") - response, err := CompleteChallenge(target, public, private, challenge) + response, err := CompleteChallenge(client, target, public, private, challenge) if err != nil { return err } - log.Println("uploading file") - - return SendFile(target, response.Token, file) + return SendFile(client, target, response.Token, file) } diff --git a/client/progress.go b/client/progress.go index 4aa8cba..d262471 100644 --- a/client/progress.go +++ b/client/progress.go @@ -2,34 +2,30 @@ package main import ( "io" - - "github.com/coalaura/progress" ) type ProgressReader struct { io.Reader - bar *progress.Bar + label string + total int64 + read int64 } func NewProgressReader(label string, total int64, reader io.Reader) *ProgressReader { - bar := progress.NewProgressBarWithTheme(label, total, progress.ThemeDots) - - bar.Start() - return &ProgressReader{ Reader: reader, - bar: bar, + label: label, + total: total, } } func (pr *ProgressReader) Read(p []byte) (int, error) { n, err := pr.Reader.Read(p) - pr.bar.IncrementBy(int64(n)) + pr.read += int64(n) + + percentage := float64(pr.read) / float64(pr.total) * 100 + log.Printf("\r%s %.1f%%", pr.label, percentage) return n, err } - -func (pr *ProgressReader) Close() { - pr.bar.Stop() -} diff --git a/client/protocol.go b/client/protocol.go index 394882b..95ac750 100644 --- a/client/protocol.go +++ b/client/protocol.go @@ -17,7 +17,7 @@ import ( "golang.org/x/crypto/ssh" ) -func RequestChallenge(target, public string) (*internal.AuthChallenge, error) { +func RequestChallenge(client *http.Client, target, public string) (*internal.AuthChallenge, error) { request, err := msgpack.Marshal(internal.AuthRequest{ Public: public, }) @@ -25,7 +25,7 @@ func RequestChallenge(target, public string) (*internal.AuthChallenge, error) { return nil, fmt.Errorf("failed to marshal request: %v", err) } - response, err := http.Post(fmt.Sprintf("%s/request", target), "application/msgpack", bytes.NewReader(request)) + response, err := client.Post(fmt.Sprintf("%s/request", target), "application/msgpack", bytes.NewReader(request)) if err != nil { return nil, fmt.Errorf("failed to send request: %v", err) } @@ -45,7 +45,7 @@ func RequestChallenge(target, public string) (*internal.AuthChallenge, error) { return &challenge, nil } -func CompleteChallenge(target, public string, private ssh.Signer, challenge *internal.AuthChallenge) (*internal.AuthResponse, error) { +func CompleteChallenge(client *http.Client, target, public string, private ssh.Signer, challenge *internal.AuthChallenge) (*internal.AuthResponse, error) { rawChallenge, err := base64.StdEncoding.DecodeString(challenge.Challenge) if err != nil { return nil, fmt.Errorf("failed to decode challenge: %v", err) @@ -66,7 +66,7 @@ func CompleteChallenge(target, public string, private ssh.Signer, challenge *int return nil, fmt.Errorf("failed to marshal request: %v", err) } - response, err := http.Post(fmt.Sprintf("%s/complete", target), "application/msgpack", bytes.NewReader(request)) + response, err := client.Post(fmt.Sprintf("%s/complete", target), "application/msgpack", bytes.NewReader(request)) if err != nil { return nil, fmt.Errorf("failed to send request: %v", err) } @@ -86,7 +86,7 @@ func CompleteChallenge(target, public string, private ssh.Signer, challenge *int return &result, nil } -func SendFile(target, token string, file *os.File) error { +func SendFile(client *http.Client, target, token string, file *os.File) error { stat, err := file.Stat() if err != nil { return fmt.Errorf("failed to stat file: %v", err) @@ -115,8 +115,7 @@ func SendFile(target, token string, file *os.File) error { writer.Close() }() - reader := NewProgressReader("Uploading", stat.Size(), pReader) - defer reader.Close() + reader := NewProgressReader("uploading file", stat.Size(), pReader) request, err := http.NewRequest("POST", fmt.Sprintf("%s/receive", target), reader) if err != nil { @@ -126,13 +125,12 @@ func SendFile(target, token string, file *os.File) error { request.Header.Set("Content-Type", writer.FormDataContentType()) request.Header.Set("Authorization", token) - response, err := http.DefaultClient.Do(request) + response, err := client.Do(request) if err != nil { return fmt.Errorf("failed to send request: %v", err) } response.Body.Close() - reader.Close() if response.StatusCode != http.StatusOK { return errors.New(response.Status) diff --git a/go.mod b/go.mod index 8895812..16ca091 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.24.2 require ( github.com/coalaura/logger v1.4.5 - github.com/coalaura/progress v1.1.6 github.com/go-chi/chi/v5 v5.2.2 github.com/urfave/cli/v3 v3.3.8 github.com/vmihailenco/msgpack/v5 v5.4.1 diff --git a/go.sum b/go.sum index 7fa8934..bfcfdec 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/coalaura/logger v1.4.5 h1:xXazOab4qXaltUbD4TrQdSs2TtLB+k6t0t6y/M8LR3Q= github.com/coalaura/logger v1.4.5/go.mod h1:3HCYCWmsWmYW175e2/fZL9BWjJutr2W+7adeh1BPHkg= -github.com/coalaura/progress v1.1.6 h1:SOeuvH3M/sUDezyjCZwBaoMWKyPVcBzhxrL1qZqtV2w= -github.com/coalaura/progress v1.1.6/go.mod h1:2t8PFWZG8m+c6x8fBrGyJajclxVDLytNctoXVYyFPbc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= diff --git a/server/certificate.go b/server/certificate.go new file mode 100644 index 0000000..70bd5a4 --- /dev/null +++ b/server/certificate.go @@ -0,0 +1,73 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "time" +) + +func EnsureCertificate(certPath, keyPath string) error { + if _, err := os.Stat(certPath); err == nil { + if _, err = os.Stat(keyPath); err == nil { + return nil + } + } + + private, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + + now := time.Now() + + serial, err := rand.Int(rand.Reader, big.NewInt(1<<62)) + if err != nil { + return err + } + + template := x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: "up"}, + NotBefore: now, + NotAfter: now.AddDate(1, 0, 0), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + certificate, err := x509.CreateCertificate(rand.Reader, &template, &template, &private.PublicKey, private) + if err != nil { + return err + } + + cFile, err := os.OpenFile(certPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + return err + } + + defer cFile.Close() + + err = pem.Encode(cFile, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certificate, + }) + if err != nil { + return err + } + + kFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + return err + } + + defer kFile.Close() + + return pem.Encode(kFile, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(private), + }) +} diff --git a/server/main.go b/server/main.go index 2f89f76..83fa37f 100644 --- a/server/main.go +++ b/server/main.go @@ -9,7 +9,7 @@ import ( ) var ( - log = logger.New().WithOptions(logger.Options{ + log = logger.New().DetectTerminal().WithOptions(logger.Options{ NoLevel: true, }) @@ -21,6 +21,9 @@ func main() { authorized, err := LoadAuthorizedKeys() log.MustPanic(err) + err = EnsureCertificate("cert.pem", "key.pem") + log.MustPanic(err) + r := chi.NewRouter() r.Post("/request", func(w http.ResponseWriter, r *http.Request) { @@ -34,5 +37,5 @@ func main() { r.Post("/receive", HandleReceiveRequest) log.Println("Listening on :7966") - http.ListenAndServe(":7966", r) + http.ListenAndServeTLS(":7966", "cert.pem", "key.pem", r) }