diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 73f9e5b..a7a9420 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -12,12 +12,6 @@ import ( "github.com/spf13/cobra" ) -// BrowserLoginFunc is the function used to perform browser-based login. -// Tests may override this to avoid opening a real browser. -var BrowserLoginFunc = func(host string, ios *iostreams.IOStreams) (string, error) { - return internalauth.BrowserLogin(host, ios) -} - // DeviceLoginFunc is the function used to perform device-code login. // Tests may override this to avoid real network calls. var DeviceLoginFunc = func(host string, ios *iostreams.IOStreams) (string, error) { @@ -41,16 +35,16 @@ func NewLoginCmd(f *cmdutil.Factory) *cobra.Command { Use: "login", Short: "Log in to KeeperHub", Args: cobra.NoArgs, - Long: `Authenticate with KeeperHub. By default opens a browser for OAuth. -Use --no-browser for device code flow on headless or SSH environments. + Long: `Authenticate with KeeperHub using the device code flow. +Opens a browser to confirm a one-time code. Use --with-token to read an API key from stdin for non-interactive automation. See also: kh auth status, kh auth logout`, - Example: ` # Log in via browser + Example: ` # Log in (device code flow) kh auth login - # Log in on a headless machine - kh auth login --no-browser`, + # Log in with an API key (non-interactive) + echo "kh_xxx" | kh auth login --with-token`, RunE: func(cmd *cobra.Command, args []string) error { hosts, err := config.ReadHosts() if err != nil { @@ -66,13 +60,11 @@ See also: kh auth status, kh auth logout`, envHost := os.Getenv("KH_HOST") host := hosts.ActiveHost(flagHost, envHost) - noBrowser, _ := cmd.Flags().GetBool("no-browser") withToken, _ := cmd.Flags().GetBool("with-token") var token string - switch { - case withToken: + if withToken { t, readErr := internalauth.ReadTokenFromStdin(f.IOStreams) if readErr != nil { return readErr @@ -81,20 +73,12 @@ See also: kh auth status, kh auth logout`, return fmt.Errorf("storing token: %w", err) } token = t - - case noBrowser: + } else { t, loginErr := DeviceLoginFunc(host, f.IOStreams) if loginErr != nil { return loginErr } token = t - - default: - t, loginErr := BrowserLoginFunc(host, f.IOStreams) - if loginErr != nil { - return loginErr - } - token = t } if token == "" { @@ -103,7 +87,6 @@ See also: kh auth status, kh auth logout`, info, err := FetchTokenInfoFunc(host, token) if err != nil { - // Non-fatal: login succeeded but we can't fetch user details fmt.Fprintf(f.IOStreams.Out, "Logged in to %s\n", host) return nil } @@ -113,7 +96,6 @@ See also: kh auth status, kh auth logout`, }, } - cmd.Flags().Bool("no-browser", false, "Do not open a browser window") cmd.Flags().Bool("with-token", false, "Read token from stdin") return cmd diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 8e44c14..97d59ca 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -10,51 +10,12 @@ import ( "github.com/keeperhub/cli/pkg/iostreams" ) -func TestLoginCmd_BrowserFlow(t *testing.T) { - // Isolate from real config.yml so ActiveHost returns the hardcoded default. +func TestLoginCmd_DefaultDeviceFlow(t *testing.T) { t.Setenv("XDG_CONFIG_HOME", t.TempDir()) ios, buf, _, _ := iostreams.Test() - browserCalled := false - auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - browserCalled = true - return "test-token", nil - } - auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - t.Fatal("DeviceLogin should not be called in browser flow") - return "", nil - } - auth.SetTokenFunc = func(host, token string) error { return nil } - auth.FetchTokenInfoFunc = func(host, token string) (internalauth.TokenInfo, error) { - return internalauth.TokenInfo{Email: "user@example.com"}, nil - } - - f := &cmdutil.Factory{IOStreams: ios} - cmd := auth.NewLoginCmd(f) - cmd.SetArgs([]string{}) - - err := cmd.Execute() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !browserCalled { - t.Error("expected BrowserLogin to be called") - } - out := buf.String() - if !strings.Contains(out, "app.keeperhub.com") { - t.Errorf("expected host in output, got: %q", out) - } -} - -func TestLoginCmd_NoBrowserFlag(t *testing.T) { - ios, _, _, _ := iostreams.Test() - deviceCalled := false - auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - t.Fatal("BrowserLogin should not be called with --no-browser") - return "", nil - } auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { deviceCalled = true return "test-token", nil @@ -66,7 +27,7 @@ func TestLoginCmd_NoBrowserFlag(t *testing.T) { f := &cmdutil.Factory{IOStreams: ios} cmd := auth.NewLoginCmd(f) - cmd.SetArgs([]string{"--no-browser"}) + cmd.SetArgs([]string{}) err := cmd.Execute() if err != nil { @@ -75,6 +36,10 @@ func TestLoginCmd_NoBrowserFlag(t *testing.T) { if !deviceCalled { t.Error("expected DeviceLogin to be called") } + out := buf.String() + if !strings.Contains(out, "app.keeperhub.com") { + t.Errorf("expected host in output, got: %q", out) + } } func TestLoginCmd_WithTokenFlag(t *testing.T) { @@ -83,10 +48,6 @@ func TestLoginCmd_WithTokenFlag(t *testing.T) { storeHost := "" storeToken := "" - auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - t.Fatal("BrowserLogin should not be called with --with-token") - return "", nil - } auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { t.Fatal("DeviceLogin should not be called with --with-token") return "", nil @@ -124,7 +85,6 @@ func TestLoginCmd_WithTokenFlag_EmptyStdin(t *testing.T) { ios, _, _, _ := iostreams.Test() ios.In = strings.NewReader("") - auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { return "", nil } auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { return "", nil } auth.SetTokenFunc = func(host, token string) error { return nil } auth.FetchTokenInfoFunc = func(host, token string) (internalauth.TokenInfo, error) { diff --git a/internal/auth/device.go b/internal/auth/device.go index 03bb1ab..5d04fca 100644 --- a/internal/auth/device.go +++ b/internal/auth/device.go @@ -6,9 +6,12 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" + "strings" "time" + "github.com/keeperhub/cli/internal/config" khhttp "github.com/keeperhub/cli/internal/http" "github.com/keeperhub/cli/pkg/iostreams" ) @@ -33,7 +36,16 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) defer cancel() - codeResp, err := requestDeviceCode(ctx, host) + // Load per-host headers (e.g. CF-Access) from hosts.yml. + var hostHeaders map[string]string + if hostsCfg, hostsErr := config.ReadHosts(); hostsErr == nil { + if entry, ok := hostsCfg.HostEntry(host); ok { + hostHeaders = entry.Headers + } + } + + baseURL := khhttp.BuildBaseURL(host) + codeResp, err := requestDeviceCode(ctx, baseURL, hostHeaders) if err != nil { return "", err } @@ -46,7 +58,7 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { interval = 5 * time.Second } - token, err := pollDeviceToken(ctx, host, codeResp.DeviceCode, interval) + token, err := pollDeviceToken(ctx, baseURL, codeResp.DeviceCode, interval, hostHeaders) if err != nil { return "", err } @@ -58,18 +70,22 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { return token, nil } -func requestDeviceCode(ctx context.Context, host string) (deviceCodeResponse, error) { +func requestDeviceCode(ctx context.Context, baseURL string, headers map[string]string) (deviceCodeResponse, error) { body, err := json.Marshal(map[string]string{"client_id": "kh-cli"}) if err != nil { return deviceCodeResponse{}, err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, - khhttp.BuildBaseURL(host)+"/api/auth/device/code", bytes.NewReader(body)) + baseURL+"/api/auth/device/code", bytes.NewReader(body)) if err != nil { return deviceCodeResponse{}, err } req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", baseURL) + for k, v := range headers { + req.Header.Set(k, v) + } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -89,7 +105,7 @@ func requestDeviceCode(ctx context.Context, host string) (deviceCodeResponse, er return codeResp, nil } -func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time.Duration) (string, error) { +func pollDeviceToken(ctx context.Context, baseURL, deviceCode string, interval time.Duration, headers map[string]string) (string, error) { for { select { case <-ctx.Done(): @@ -97,7 +113,7 @@ func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time case <-time.After(interval): } - tokenResp, err := checkDeviceToken(ctx, host, deviceCode) + tokenResp, err := checkDeviceToken(ctx, baseURL, deviceCode, headers) if err != nil { return "", err } @@ -121,7 +137,7 @@ func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time } } -func checkDeviceToken(ctx context.Context, host, deviceCode string) (deviceTokenResponse, error) { +func checkDeviceToken(ctx context.Context, baseURL, deviceCode string, headers map[string]string) (deviceTokenResponse, error) { body, err := json.Marshal(map[string]string{ "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": deviceCode, @@ -132,11 +148,15 @@ func checkDeviceToken(ctx context.Context, host, deviceCode string) (deviceToken } req, err := http.NewRequestWithContext(ctx, http.MethodPost, - khhttp.BuildBaseURL(host)+"/api/auth/device/token", bytes.NewReader(body)) + baseURL+"/api/auth/device/token", bytes.NewReader(body)) if err != nil { return deviceTokenResponse{}, err } req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", baseURL) + for k, v := range headers { + req.Header.Set(k, v) + } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -151,3 +171,16 @@ func checkDeviceToken(ctx context.Context, host, deviceCode string) (deviceToken return tokenResp, nil } + +// ReadTokenFromStdin reads a token from ios.In, trims whitespace, and returns it. +func ReadTokenFromStdin(ios *iostreams.IOStreams) (string, error) { + data, err := io.ReadAll(ios.In) + if err != nil { + return "", fmt.Errorf("reading token from stdin: %w", err) + } + token := strings.TrimSpace(string(data)) + if token == "" { + return "", errors.New("no token provided on stdin") + } + return token, nil +} diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go deleted file mode 100644 index 131b67f..0000000 --- a/internal/auth/oauth.go +++ /dev/null @@ -1,128 +0,0 @@ -package auth - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "os/exec" - "runtime" - "strings" - "time" - - khhttp "github.com/keeperhub/cli/internal/http" - "github.com/keeperhub/cli/pkg/iostreams" -) - -// browserOpener opens a URL in the default browser. Tests override this. -var browserOpener = openBrowser - -func openBrowser(url string) error { - var cmd string - var args []string - switch runtime.GOOS { - case "darwin": - cmd = "open" - args = []string{url} - case "windows": - cmd = "rundll32" - args = []string{"url.dll,FileProtocolHandler", url} - default: - cmd = "xdg-open" - args = []string{url} - } - return exec.Command(cmd, args...).Start() -} - -// BrowserLogin starts a localhost OAuth callback server, opens the browser to -// authenticate via GitHub OAuth, captures the session token from the callback, -// stores it in the keyring, and returns the token. -func BrowserLogin(host string, ios *iostreams.IOStreams) (string, error) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return "", fmt.Errorf("starting callback server: %w", err) - } - port := listener.Addr().(*net.TCPAddr).Port - - tokenCh := make(chan string, 1) - errCh := make(chan error, 1) - - mux := http.NewServeMux() - srv := &http.Server{ - Handler: mux, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - } - - mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { - token := r.URL.Query().Get("token") - if token == "" { - // Try cookie fallback - if cookie, cookieErr := r.Cookie("better-auth.session_token"); cookieErr == nil { - token = cookie.Value - } - } - if token == "" { - http.Error(w, "No token received", http.StatusBadRequest) - errCh <- errors.New("no token in callback") - return - } - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, "

Authentication successful!

You can close this tab.

") - tokenCh <- token - }) - - go func() { - if serveErr := srv.Serve(listener); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { - errCh <- serveErr - } - }() - - authURL := fmt.Sprintf("%s/api/auth/sign-in/social?provider=github&callbackURL=http://127.0.0.1:%d/callback", khhttp.BuildBaseURL(host), port) - fmt.Fprintf(ios.Out, "Opening browser to authenticate...\n") - - if err := browserOpener(authURL); err != nil { - _ = srv.Close() - return "", fmt.Errorf("opening browser: %w", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - var token string - select { - case token = <-tokenCh: - case err = <-errCh: - _ = srv.Close() - return "", err - case <-ctx.Done(): - _ = srv.Close() - return "", errors.New("timed out waiting for browser authentication") - } - - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 3*time.Second) - defer shutdownCancel() - _ = srv.Shutdown(shutdownCtx) - - if storeErr := SetToken(host, token); storeErr != nil { - return "", fmt.Errorf("storing token: %w", storeErr) - } - - return token, nil -} - -// ReadTokenFromStdin reads a token from ios.In, trims whitespace, and returns it. -// Returns an error if the input is empty after trimming. -func ReadTokenFromStdin(ios *iostreams.IOStreams) (string, error) { - data, err := io.ReadAll(ios.In) - if err != nil { - return "", fmt.Errorf("reading token from stdin: %w", err) - } - token := strings.TrimSpace(string(data)) - if token == "" { - return "", errors.New("no token provided on stdin") - } - return token, nil -} diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go deleted file mode 100644 index 2b08336..0000000 --- a/internal/auth/oauth_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package auth - -import ( - "fmt" - "net/http" - "strings" - "testing" - "time" - - "github.com/keeperhub/cli/pkg/iostreams" - "github.com/stretchr/testify/require" -) - -// installBrowserCapture overrides browserOpener so it sends the auth URL on a channel. -// Must be called BEFORE the BrowserLogin goroutine is started. -func installBrowserCapture(t *testing.T) <-chan string { - t.Helper() - ch := make(chan string, 1) - browserOpener = func(url string) error { - ch <- url - return nil - } - t.Cleanup(func() { browserOpener = openBrowser }) - return ch -} - -func extractCallbackPort(authURL string) string { - // authURL looks like: https://host/api/auth/sign-in/social?provider=github&callbackURL=http://127.0.0.1:PORT/callback - idx := strings.Index(authURL, "callbackURL=") - if idx < 0 { - return "" - } - cb := authURL[idx+len("callbackURL="):] - portIdx := strings.LastIndex(cb, ":") - slashIdx := strings.Index(cb[portIdx:], "/") - return cb[portIdx+1 : portIdx+slashIdx] -} - -func TestBrowserLogin_CapturesToken(t *testing.T) { - overrideKeyring(t) - urlCh := installBrowserCapture(t) // set before goroutine starts - - ios, _, _, _ := iostreams.Test() - - tokenCh := make(chan string, 1) - errCh := make(chan error, 1) - - go func() { - tok, err := BrowserLogin("app.keeperhub.com", ios) - if err != nil { - errCh <- err - return - } - tokenCh <- tok - }() - - // Wait for browser opener to be called. - var authURL string - select { - case authURL = <-urlCh: - case <-time.After(5 * time.Second): - t.Fatal("browser opener was never called") - } - - port := extractCallbackPort(authURL) - require.NotEmpty(t, port) - - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback?token=test_token_123", port)) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - resp.Body.Close() - - select { - case tok := <-tokenCh: - require.Equal(t, "test_token_123", tok) - case err := <-errCh: - t.Fatalf("BrowserLogin returned error: %v", err) - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for BrowserLogin to return") - } - - stored, err := GetToken("app.keeperhub.com") - require.NoError(t, err) - require.Equal(t, "test_token_123", stored) -} - -func TestBrowserLogin_NoTokenInCallback(t *testing.T) { - overrideKeyring(t) - urlCh := installBrowserCapture(t) - - ios, _, _, _ := iostreams.Test() - - errCh := make(chan error, 1) - go func() { - _, err := BrowserLogin("app.keeperhub.com", ios) - errCh <- err - }() - - var authURL string - select { - case authURL = <-urlCh: - case <-time.After(5 * time.Second): - t.Fatal("browser opener was never called") - } - - port := extractCallbackPort(authURL) - require.NotEmpty(t, port) - - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback", port)) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, resp.StatusCode) - resp.Body.Close() - - select { - case err := <-errCh: - require.Error(t, err) - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for BrowserLogin to return error") - } -} - -func TestReadTokenFromStdin(t *testing.T) { - ios, _, _, _ := iostreams.Test() - ios.In = strings.NewReader(" my_token_value\n ") - - tok, err := ReadTokenFromStdin(ios) - require.NoError(t, err) - require.Equal(t, "my_token_value", tok) -} - -func TestReadTokenFromStdin_Empty(t *testing.T) { - ios, _, _, _ := iostreams.Test() - ios.In = strings.NewReader(" \n ") - - _, err := ReadTokenFromStdin(ios) - require.Error(t, err) -}