1
0
mirror of https://github.com/coalaura/up.git synced 2025-07-18 21:53:23 +00:00

validate server key

This commit is contained in:
Laura
2025-06-20 17:10:03 +02:00
parent 3f17910502
commit 80b9989dd0
9 changed files with 300 additions and 43 deletions

178
client/certificates.go Normal file
View File

@ -0,0 +1,178 @@
package main
import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
)
type PinnedCertificate struct {
Name string
Fingerprint string
}
type CertificateStore struct {
path string
pinned []PinnedCertificate
mx sync.RWMutex
}
func (cs *CertificateStore) IsPinned(name, fingerprint string) bool {
cs.mx.RLock()
defer cs.mx.RUnlock()
if len(cs.pinned) == 0 {
return false
}
for _, pin := range cs.pinned {
if pin.Fingerprint == fingerprint && pin.Name == name {
return true
}
}
return false
}
func (cs *CertificateStore) Pin(name, fingerprint string) error {
cs.mx.Lock()
defer cs.mx.Unlock()
pin := PinnedCertificate{
Name: name,
Fingerprint: fingerprint,
}
file, err := os.OpenFile(cs.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return err
}
defer file.Close()
if _, err = file.WriteString(fmt.Sprintf("%s %s\n", name, fingerprint)); err != nil {
return err
}
cs.pinned = append(cs.pinned, pin)
return nil
}
func LoadCertificateStore() (*CertificateStore, error) {
path, err := PinnedCertificatesPath()
if err != nil {
return nil, err
}
store := &CertificateStore{
path: path,
}
if _, err := os.Stat(path); os.IsNotExist(err) {
return store, nil
}
contents, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var index int
for line := range bytes.SplitSeq(contents, []byte("\n")) {
index++
if len(line) == 0 {
continue
}
index := bytes.Index(line, []byte(" "))
if index == -1 {
return nil, fmt.Errorf("Invalid pinned certificate on line %d\n", index)
}
name := line[:index]
fingerprint := line[index:]
if len(fingerprint) < 64 {
return nil, fmt.Errorf("Invalid fingerprint on line %d\n", index)
}
store.pinned = append(store.pinned, PinnedCertificate{
Name: string(name),
Fingerprint: string(fingerprint),
})
}
return store, nil
}
func PinnedCertificatesPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".up"), nil
}
func CertificateFingerprint(certificate *x509.Certificate) string {
sum := sha256.Sum256(certificate.Raw)
algo := strings.ToLower(certificate.PublicKeyAlgorithm.String())
return fmt.Sprintf("%s-%s", algo, hex.EncodeToString(sum[:]))
}
func NewPinnedClient(store *CertificateStore) *http.Client {
config := &tls.Config{
InsecureSkipVerify: true,
}
config.VerifyConnection = func(cs tls.ConnectionState) error {
if len(cs.PeerCertificates) == 0 {
return errors.New("missing certificate")
}
certificate := cs.PeerCertificates[0]
if certificate.Subject.CommonName != "up" {
return errors.New("invalid certificate subject")
}
name := cs.ServerName
fingerprint := CertificateFingerprint(certificate)
if store.IsPinned(name, fingerprint) {
return nil
}
log.Printf("Server fingerprint (%s): %s\n", name, fingerprint)
log.Print("Accept? [y/N]: ")
var confirm string
fmt.Scanln(&confirm)
if strings.ToLower(strings.TrimSpace(confirm)) != "y" {
return errors.New("certificate rejected")
}
return store.Pin(name, fingerprint)
}
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: config,
},
}
}

View File

@ -5,14 +5,21 @@ import (
"encoding/base64"
"errors"
"fmt"
"log"
"os"
"path/filepath"
"github.com/coalaura/logger"
"github.com/urfave/cli/v3"
)
var Version = "dev"
var (
Version = "dev"
log = logger.New().DetectTerminal().WithOptions(logger.Options{
NoTime: true,
NoLevel: true,
})
)
func main() {
app := &cli.Command{
@ -42,14 +49,20 @@ func main() {
Suggest: true,
}
if err := app.Run(context.Background(), os.Args); err != nil {
fmt.Printf("fatal: %v\n", err)
os.Exit(1)
}
err := app.Run(context.Background(), os.Args)
log.MustPanic(err)
}
func run(_ context.Context, cmd *cli.Command) error {
log.Println("loading certificate store")
store, err := LoadCertificateStore()
if err != nil {
return fmt.Errorf("failed to load certificate store: %v", err)
}
client := NewPinnedClient(store)
path := cmd.String("key")
if path == "" {
return errors.New("missing private key")
@ -60,7 +73,7 @@ func run(_ context.Context, cmd *cli.Command) error {
return fmt.Errorf("failed to get key path: %v", err)
}
log.Printf("using key %s", kPath)
log.Printf("using key %s\n", kPath)
path = cmd.String("file")
if path == "" {
@ -72,7 +85,7 @@ func run(_ context.Context, cmd *cli.Command) error {
return fmt.Errorf("failed to get file path: %v", err)
}
log.Printf("using file %s", fPath)
log.Printf("using file %s\n", fPath)
file, err := os.OpenFile(fPath, os.O_RDONLY, 0)
if err != nil {
@ -88,7 +101,7 @@ func run(_ context.Context, cmd *cli.Command) error {
target = fmt.Sprintf("https://%s", target)
log.Printf("using target %s", target)
log.Printf("using target %s\n", target)
log.Printf("loading key")
@ -101,19 +114,17 @@ func run(_ context.Context, cmd *cli.Command) error {
log.Println("requesting challenge")
challenge, err := RequestChallenge(target, public)
challenge, err := RequestChallenge(client, target, public)
if err != nil {
return err
}
log.Println("completing challenge")
response, err := CompleteChallenge(target, public, private, challenge)
response, err := CompleteChallenge(client, target, public, private, challenge)
if err != nil {
return err
}
log.Println("uploading file")
return SendFile(target, response.Token, file)
return SendFile(client, target, response.Token, file)
}

View File

@ -2,34 +2,30 @@ package main
import (
"io"
"github.com/coalaura/progress"
)
type ProgressReader struct {
io.Reader
bar *progress.Bar
label string
total int64
read int64
}
func NewProgressReader(label string, total int64, reader io.Reader) *ProgressReader {
bar := progress.NewProgressBarWithTheme(label, total, progress.ThemeDots)
bar.Start()
return &ProgressReader{
Reader: reader,
bar: bar,
label: label,
total: total,
}
}
func (pr *ProgressReader) Read(p []byte) (int, error) {
n, err := pr.Reader.Read(p)
pr.bar.IncrementBy(int64(n))
pr.read += int64(n)
percentage := float64(pr.read) / float64(pr.total) * 100
log.Printf("\r%s %.1f%%", pr.label, percentage)
return n, err
}
func (pr *ProgressReader) Close() {
pr.bar.Stop()
}

View File

@ -17,7 +17,7 @@ import (
"golang.org/x/crypto/ssh"
)
func RequestChallenge(target, public string) (*internal.AuthChallenge, error) {
func RequestChallenge(client *http.Client, target, public string) (*internal.AuthChallenge, error) {
request, err := msgpack.Marshal(internal.AuthRequest{
Public: public,
})
@ -25,7 +25,7 @@ func RequestChallenge(target, public string) (*internal.AuthChallenge, error) {
return nil, fmt.Errorf("failed to marshal request: %v", err)
}
response, err := http.Post(fmt.Sprintf("%s/request", target), "application/msgpack", bytes.NewReader(request))
response, err := client.Post(fmt.Sprintf("%s/request", target), "application/msgpack", bytes.NewReader(request))
if err != nil {
return nil, fmt.Errorf("failed to send request: %v", err)
}
@ -45,7 +45,7 @@ func RequestChallenge(target, public string) (*internal.AuthChallenge, error) {
return &challenge, nil
}
func CompleteChallenge(target, public string, private ssh.Signer, challenge *internal.AuthChallenge) (*internal.AuthResponse, error) {
func CompleteChallenge(client *http.Client, target, public string, private ssh.Signer, challenge *internal.AuthChallenge) (*internal.AuthResponse, error) {
rawChallenge, err := base64.StdEncoding.DecodeString(challenge.Challenge)
if err != nil {
return nil, fmt.Errorf("failed to decode challenge: %v", err)
@ -66,7 +66,7 @@ func CompleteChallenge(target, public string, private ssh.Signer, challenge *int
return nil, fmt.Errorf("failed to marshal request: %v", err)
}
response, err := http.Post(fmt.Sprintf("%s/complete", target), "application/msgpack", bytes.NewReader(request))
response, err := client.Post(fmt.Sprintf("%s/complete", target), "application/msgpack", bytes.NewReader(request))
if err != nil {
return nil, fmt.Errorf("failed to send request: %v", err)
}
@ -86,7 +86,7 @@ func CompleteChallenge(target, public string, private ssh.Signer, challenge *int
return &result, nil
}
func SendFile(target, token string, file *os.File) error {
func SendFile(client *http.Client, target, token string, file *os.File) error {
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file: %v", err)
@ -115,8 +115,7 @@ func SendFile(target, token string, file *os.File) error {
writer.Close()
}()
reader := NewProgressReader("Uploading", stat.Size(), pReader)
defer reader.Close()
reader := NewProgressReader("uploading file", stat.Size(), pReader)
request, err := http.NewRequest("POST", fmt.Sprintf("%s/receive", target), reader)
if err != nil {
@ -126,13 +125,12 @@ func SendFile(target, token string, file *os.File) error {
request.Header.Set("Content-Type", writer.FormDataContentType())
request.Header.Set("Authorization", token)
response, err := http.DefaultClient.Do(request)
response, err := client.Do(request)
if err != nil {
return fmt.Errorf("failed to send request: %v", err)
}
response.Body.Close()
reader.Close()
if response.StatusCode != http.StatusOK {
return errors.New(response.Status)