1
0
mirror of https://github.com/coalaura/up.git synced 2025-07-17 21:44:35 +00:00
Files
up/client/certificates.go
2025-06-21 00:09:08 +02:00

220 lines
4.1 KiB
Go

package client
import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
type (
	// PinnedCertificate is one trusted (server name, certificate fingerprint)
	// pair as stored in the pin file.
	PinnedCertificate struct {
		Name, Fingerprint string
	}

	// CertificateStore holds the pinned certificates loaded from, and
	// appended to, the file at path. mx guards pinned.
	CertificateStore struct {
		path   string
		pinned []PinnedCertificate
		mx     sync.RWMutex
	}
)
// IsPinned reports whether cs contains a pin matching both the given server
// name and fingerprint.
//
// Both comparisons ignore case. Names read back from the pin file are
// lowercased by LoadCertificateStore, while names pinned during the current
// run are stored exactly as the caller passed them. A case-sensitive compare
// would therefore stop matching a mixed-case host name after a restart.
// Fingerprints are hex, so case carries no meaning there either.
func (cs *CertificateStore) IsPinned(name, fingerprint string) bool {
	cs.mx.RLock()
	defer cs.mx.RUnlock()

	for _, pin := range cs.pinned {
		if strings.EqualFold(pin.Fingerprint, fingerprint) && strings.EqualFold(pin.Name, name) {
			return true
		}
	}

	return false
}
// Pin appends a "<name> <fingerprint>" line to the pin file at cs.path and
// then records the pair in memory. The file is created with mode 0600 if it
// does not exist yet.
//
// The in-memory list is only updated once the line has been written and the
// file closed without error, so a failed Pin leaves cs unchanged. Pin does not
// check whether an identical entry already exists.
func (cs *CertificateStore) Pin(name, fingerprint string) error {
	// Held across the file write so concurrent Pin calls cannot interleave lines.
	cs.mx.Lock()
	defer cs.mx.Unlock()

	file, err := os.OpenFile(cs.path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		return err
	}

	if _, err := fmt.Fprintf(file, "%s %s\n", name, fingerprint); err != nil {
		_ = file.Close() // the write error is the one worth reporting
		return err
	}

	// For a file we just wrote to, a close error can mean the line was not
	// persisted, so it must not be dropped.
	if err := file.Close(); err != nil {
		return err
	}

	cs.pinned = append(cs.pinned, PinnedCertificate{
		Name:        name,
		Fingerprint: fingerprint,
	})

	return nil
}
// LoadCertificateStore reads the pin file at PinnedCertificatesPath and
// returns a store backed by it. A missing file yields an empty store.
//
// Each non-blank line holds "<name> <fingerprint>". Both parts are
// lowercased on load, and the fingerprint must be at least 64 characters. A
// line consisting of a leading space followed by only a fingerprint is also
// accepted and loaded with an empty name. That is exactly what Pin writes
// when the server name is empty. Rejecting it would make every subsequent
// load fail.
func LoadCertificateStore() (*CertificateStore, error) {
	path, err := PinnedCertificatesPath()
	if err != nil {
		return nil, err
	}

	store := &CertificateStore{
		path: path,
	}

	// A single ReadFile avoids the race between a separate Stat and the read.
	contents, err := os.ReadFile(path)
	if err != nil {
		if os.IsNotExist(err) {
			return store, nil
		}

		return nil, err
	}

	var index int

	for line := range bytes.SplitSeq(contents, []byte("\n")) {
		index++

		fields := bytes.Fields(line)

		var name, fingerprint []byte

		switch {
		case len(fields) == 0:
			// Blank or whitespace-only line (including a lone "\r" from CRLF files).
			continue
		case len(fields) == 2:
			name, fingerprint = fields[0], fields[1]
		case len(fields) == 1 && line[0] == ' ':
			// Written by Pin as "%s %s" with an empty name.
			fingerprint = fields[0]
		default:
			return nil, fmt.Errorf("%s: invalid pinned certificate on line %d", path, index)
		}

		fingerprint = bytes.ToLower(fingerprint)

		if len(fingerprint) < 64 {
			return nil, fmt.Errorf("%s: invalid fingerprint on line %d", path, index)
		}

		store.pinned = append(store.pinned, PinnedCertificate{
			Name:        string(bytes.ToLower(name)),
			Fingerprint: string(fingerprint),
		})
	}

	return store, nil
}
// PinnedCertificatesPath returns the location of the pin file: ".up" inside
// the current user's home directory. It fails only if the home directory
// cannot be determined.
func PinnedCertificatesPath() (string, error) {
	dir, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(dir, ".up"), nil
	}

	return "", err
}
// CertificateFingerprint identifies a certificate as
// "<lowercased public-key algorithm>-<hex SHA-256 of its raw DER bytes>",
// for example "ecdsa-ba78…".
func CertificateFingerprint(certificate *x509.Certificate) string {
	digest := sha256.Sum256(certificate.Raw)
	prefix := strings.ToLower(certificate.PublicKeyAlgorithm.String())

	return prefix + "-" + hex.EncodeToString(digest[:])
}
// PreFetchServerCertificate connects to addr and checks that the server's
// leaf certificate has CommonName "up".
//
// If the (server name, fingerprint) pair is not yet pinned in store, it
// prints the fingerprint and asks the user to accept it. An answer of "y"
// (any case) pins it via store.Pin; anything else returns
// "certificate rejected".
//
// The TLS connection is closed before the prompt, so it is not held open
// while waiting on user input.
func PreFetchServerCertificate(store *CertificateStore, addr string) error {
	name, certificate, err := dialLeafCertificate(addr)
	if err != nil {
		return err
	}

	if certificate.Subject.CommonName != "up" {
		return errors.New("invalid certificate subject")
	}

	fingerprint := CertificateFingerprint(certificate)

	if store.IsPinned(name, fingerprint) {
		return nil
	}

	log.Printf("Server fingerprint for %s: %s\n", name, fingerprint)
	log.Print("Accept? [y/N]: ")

	// Scanln's error is deliberately ignored: on EOF or an empty line confirm
	// stays empty, which is treated as a rejection below.
	var confirm string
	fmt.Scanln(&confirm)

	if strings.ToLower(strings.TrimSpace(confirm)) != "y" {
		return errors.New("certificate rejected")
	}

	return store.Pin(name, fingerprint)
}

// dialLeafCertificate performs a TLS handshake with addr and returns the
// server name from the connection state together with the server's leaf
// certificate. The connection is always closed before returning.
//
// NOTE(review): crypto/tls sends no SNI for IP-literal hosts, so name may be
// empty in that case — confirm against callers.
func dialLeafCertificate(addr string) (string, *x509.Certificate, error) {
	dialer := &net.Dialer{
		Timeout: 5 * time.Second,
	}

	// Chain verification is skipped on purpose: trust comes from the pin
	// store and the interactive confirmation, not from a CA.
	conn, err := tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{
		InsecureSkipVerify: true,
	})
	if err != nil {
		return "", nil, err
	}
	defer conn.Close()

	state := conn.ConnectionState()

	if len(state.PeerCertificates) == 0 {
		return "", nil, errors.New("no peer certificates")
	}

	return state.ServerName, state.PeerCertificates[0], nil
}
// NewPinnedClient returns an HTTP client that accepts a TLS connection only if
// both of these hold:
//   - the server's leaf certificate has CommonName "up";
//   - the (server name, fingerprint) pair is pinned in store.
//
// Dialing and the TLS handshake each time out after 5 seconds, and idle
// connections are dropped after 10 seconds. The client itself sets no overall
// request timeout.
func NewPinnedClient(store *CertificateStore) *http.Client {
	config := &tls.Config{
		// CA chain verification is replaced by the pin check in VerifyConnection.
		InsecureSkipVerify: true,
	}

	// crypto/tls calls VerifyConnection even when InsecureSkipVerify is set,
	// so this is where the pin is enforced.
	config.VerifyConnection = func(cs tls.ConnectionState) error {
		if len(cs.PeerCertificates) == 0 {
			return errors.New("missing certificate")
		}

		certificate := cs.PeerCertificates[0]

		if certificate.Subject.CommonName != "up" {
			return errors.New("invalid certificate subject")
		}

		name := cs.ServerName
		fingerprint := CertificateFingerprint(certificate)

		if !store.IsPinned(name, fingerprint) {
			return errors.New("unknown certificate")
		}

		return nil
	}

	dialer := &net.Dialer{
		Timeout: 5 * time.Second,
	}

	return &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: config,
			// DialContext replaces the deprecated Dial field and lets the
			// transport cancel dials that are no longer needed.
			DialContext:         dialer.DialContext,
			TLSHandshakeTimeout: 5 * time.Second,
			IdleConnTimeout:     10 * time.Second,
		},
	}
}