diff --git a/client/certificates.go b/client/certificates.go index 1981ccc..cb28bf8 100644 --- a/client/certificates.go +++ b/client/certificates.go @@ -8,11 +8,13 @@ import ( "encoding/hex" "errors" "fmt" + "net" "net/http" "os" "path/filepath" "strings" "sync" + "time" ) type PinnedCertificate struct { @@ -96,13 +98,13 @@ func LoadCertificateStore() (*CertificateStore, error) { continue } - index := bytes.Index(line, []byte(" ")) - if index == -1 { + fields := bytes.Fields(line) + if len(fields) != 2 { return nil, fmt.Errorf("Invalid pinned certificate on line %d\n", index) } - name := line[:index] - fingerprint := line[index:] + name := bytes.ToLower(fields[0]) + fingerprint := bytes.ToLower(fields[1]) if len(fingerprint) < 64 { return nil, fmt.Errorf("Invalid fingerprint on line %d\n", index) @@ -173,6 +175,11 @@ func NewPinnedClient(store *CertificateStore) *http.Client { return &http.Client{ Transport: &http.Transport{ TLSClientConfig: config, + Dial: (&net.Dialer{ + Timeout: 5 * time.Second, + }).Dial, + TLSHandshakeTimeout: 5 * time.Second, + IdleConnTimeout: 10 * time.Second, }, } } diff --git a/client/main.go b/client/main.go index 18b6dfd..399a7b4 100644 --- a/client/main.go +++ b/client/main.go @@ -74,20 +74,26 @@ func run(_ context.Context, cmd *cli.Command) error { return fmt.Errorf("failed to load SSH config: %v", err) } - hostname := hostArg - identity := cmd.String("identity") + var ( + port string + hostname = hostArg + identity = cmd.String("identity") + ) - if cfg != nil { - if found, _ := cfg.Get(hostArg, "HostName"); found != "" { - hostname = found + if index := strings.Index(hostArg, ":"); index != -1 { + hostname = hostname[:index] + port = hostArg[index+1:] + } - if port := strings.Index(hostname, ":"); port != -1 { - hostname = hostname[:port] - } - } + if found, _ := cfg.Get(hostname, "IdentityFile"); found != "" { + identity = found + } - if found, _ := cfg.Get(hostArg, "IdentityFile"); found != "" { - identity = found + if found, _ := cfg.Get(hostname, "HostName"); found != "" { + hostname = found + + if index := strings.Index(hostname, ":"); index != -1 { + hostname = hostname[:index] } } @@ -95,6 +101,10 @@ func run(_ context.Context, cmd *cli.Command) error { return errors.New("missing or invalid host") } + if port != "" { + hostname += ":" + port + } + log.Printf("Using host: %s\n", hostname) if identity == "" {