diff --git a/client/main.go b/client/main.go index 2162ebb..abba850 100644 --- a/client/main.go +++ b/client/main.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/coalaura/logger" "github.com/urfave/cli/v3" @@ -46,6 +47,76 @@ func main() { } func run(_ context.Context, cmd *cli.Command) error { + args := cmd.Args().Slice() + if len(args) != 2 { + return errors.New("Usage: up [options] ") + } + + fileArg := args[0] + hostArg := args[1] + + path, err := filepath.Abs(fileArg) + if err != nil { + return fmt.Errorf("failed to get file path: %v", err) + } + + log.Printf("Using file: %s\n", path) + + file, err := os.OpenFile(path, os.O_RDONLY, 0) + if err != nil { + return fmt.Errorf("failed to open file: %v", err) + } + + defer file.Close() + + cfg, err := LoadSSHConfig() + if err != nil { + return fmt.Errorf("failed to load SSH config: %v", err) + } + + hostname := hostArg + identity := cmd.String("identity") + + if cfg != nil { + if found, _ := cfg.Get(hostArg, "HostName"); found != "" { + hostname = found + + if port := strings.Index(hostname, ":"); port != -1 { + hostname = hostname[:port] + } + } + + if found, _ := cfg.Get(hostArg, "IdentityFile"); found != "" { + identity = found + } + } + + if hostname == "" { + return errors.New("missing or invalid host") + } + + log.Printf("Using host: %s\n", hostname) + + if identity == "" { + return errors.New("missing or invalid identity file") + } + + path, err = filepath.Abs(identity) + if err != nil { + return fmt.Errorf("failed to get identity file path: %v", err) + } + + log.Printf("Using identity file: %s\n", path) + + log.Printf("Loading key...") + + private, err := LoadPrivateKey(path) + if err != nil { + return fmt.Errorf("failed to load key: %v", err) + } + + public := base64.StdEncoding.EncodeToString(private.PublicKey().Marshal()) + log.Println("Loading certificate store...") store, err := LoadCertificateStore() @@ -55,68 +126,19 @@ func run(_ context.Context, cmd *cli.Command) error { client := NewPinnedClient(store) - path := cmd.String("key") - if path == "" { - return errors.New("missing private key") - } - - kPath, err := filepath.Abs(path) - if err != nil { - return fmt.Errorf("failed to get key path: %v", err) - } - - log.Printf("Using key: %s\n", kPath) - - path = cmd.String("file") - if path == "" { - return errors.New("missing file") - } - - fPath, err := filepath.Abs(path) - if err != nil { - return fmt.Errorf("failed to get file path: %v", err) - } - - log.Printf("Using file: %s\n", fPath) - - file, err := os.OpenFile(fPath, os.O_RDONLY, 0) - if err != nil { - return fmt.Errorf("failed to open file: %v", err) - } - - defer file.Close() - - target := cmd.String("target") - if target == "" { - return errors.New("missing target") - } - - target = fmt.Sprintf("https://%s", target) - - log.Printf("Using target: %s\n", target) - - log.Printf("Loading key...") - - private, err := LoadPrivateKey(kPath) - if err != nil { - return fmt.Errorf("failed to load key: %v", err) - } - - public := base64.StdEncoding.EncodeToString(private.PublicKey().Marshal()) - log.Println("Requesting challenge...") - challenge, err := RequestChallenge(client, target, public) + challenge, err := RequestChallenge(client, hostname, public) if err != nil { return err } log.Println("Completing challenge...") - response, err := CompleteChallenge(client, target, public, private, challenge) + response, err := CompleteChallenge(client, hostname, public, private, challenge) if err != nil { return err } - return SendFile(client, target, response.Token, file) + return SendFile(client, hostname, response.Token, file) } diff --git a/client/protocol.go b/client/protocol.go index 95ac750..fb398f9 100644 --- a/client/protocol.go +++ b/client/protocol.go @@ -17,7 +17,7 @@ import ( "golang.org/x/crypto/ssh" ) -func RequestChallenge(client *http.Client, target, public string) (*internal.AuthChallenge, error) { +func RequestChallenge(client *http.Client, hostname, public string) (*internal.AuthChallenge, error) { request, err := msgpack.Marshal(internal.AuthRequest{ Public: public, }) @@ -25,7 +25,7 @@ func RequestChallenge(client *http.Client, target, public string) (*internal.Aut return nil, fmt.Errorf("failed to marshal request: %v", err) } - response, err := client.Post(fmt.Sprintf("%s/request", target), "application/msgpack", bytes.NewReader(request)) + response, err := client.Post(fmt.Sprintf("https://%s/request", hostname), "application/msgpack", bytes.NewReader(request)) if err != nil { return nil, fmt.Errorf("failed to send request: %v", err) } @@ -45,7 +45,7 @@ func RequestChallenge(client *http.Client, target, public string) (*internal.Aut return &challenge, nil } -func CompleteChallenge(client *http.Client, target, public string, private ssh.Signer, challenge *internal.AuthChallenge) (*internal.AuthResponse, error) { +func CompleteChallenge(client *http.Client, hostname, public string, private ssh.Signer, challenge *internal.AuthChallenge) (*internal.AuthResponse, error) { rawChallenge, err := base64.StdEncoding.DecodeString(challenge.Challenge) if err != nil { return nil, fmt.Errorf("failed to decode challenge: %v", err) @@ -66,7 +66,7 @@ func CompleteChallenge(client *http.Client, target, public string, private ssh.S return nil, fmt.Errorf("failed to marshal request: %v", err) } - response, err := client.Post(fmt.Sprintf("%s/complete", target), "application/msgpack", bytes.NewReader(request)) + response, err := client.Post(fmt.Sprintf("https://%s/complete", hostname), "application/msgpack", bytes.NewReader(request)) if err != nil { return nil, fmt.Errorf("failed to send request: %v", err) } @@ -86,7 +86,7 @@ func CompleteChallenge(client *http.Client, target, public string, private ssh.S return &result, nil } -func SendFile(client *http.Client, target, token string, file *os.File) error { +func SendFile(client *http.Client, hostname, token string, file *os.File) error { stat, err := file.Stat() if err != nil { return fmt.Errorf("failed to stat file: %v", err) @@ -115,9 +115,9 @@ func SendFile(client *http.Client, target, token string, file *os.File) error { writer.Close() }() - reader := NewProgressReader("uploading file", stat.Size(), pReader) + reader := NewProgressReader("Uploading file", stat.Size(), pReader) - request, err := http.NewRequest("POST", fmt.Sprintf("%s/receive", target), reader) + request, err := http.NewRequest("POST", fmt.Sprintf("https://%s/receive", hostname), reader) if err != nil { return fmt.Errorf("failed to create request: %v", err) } diff --git a/test.cmd b/test.cmd index aa5c325..6245071 100644 --- a/test.cmd +++ b/test.cmd @@ -1,3 +1,3 @@ @echo off -go run .\client --key example.key -f test.bin -t localhost:7966 \ No newline at end of file +go run .\client test.bin localhost:7966 --identity example.key \ No newline at end of file