From a1395ea8190e8a32ef8502a26a7f130a7c7192a9 Mon Sep 17 00:00:00 2001 From: Simon KP Date: Fri, 13 Mar 2026 21:28:55 +1100 Subject: [PATCH 1/5] fix: POST to Better Auth social sign-in endpoint instead of GET Better Auth's /sign-in/social is a POST endpoint returning a redirect URL. The CLI was opening it as a GET URL in the browser, causing a 404. Now POSTs to get the OAuth redirect URL, then opens that in the browser. --- internal/auth/oauth.go | 51 ++++++++++++++++++++++++++++++++++++- internal/auth/oauth_test.go | 9 +++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index 131b67f..4c11d29 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -1,7 +1,9 @@ package auth import ( + "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -19,6 +21,9 @@ import ( // browserOpener opens a URL in the default browser. Tests override this. var browserOpener = openBrowser +// socialSignInURLFunc fetches the OAuth redirect URL from the server. Tests override this. +var socialSignInURLFunc = fetchSocialSignInURL + func openBrowser(url string) error { var cmd string var args []string @@ -80,7 +85,16 @@ func BrowserLogin(host string, ios *iostreams.IOStreams) (string, error) { } }() - authURL := fmt.Sprintf("%s/api/auth/sign-in/social?provider=github&callbackURL=http://127.0.0.1:%d/callback", khhttp.BuildBaseURL(host), port) + // Better Auth's /sign-in/social is a POST endpoint that returns the OAuth redirect URL. + callbackURL := fmt.Sprintf("http://127.0.0.1:%d/callback", port) + baseURL := khhttp.BuildBaseURL(host) + + authURL, postErr := socialSignInURLFunc(baseURL, "github", callbackURL) + if postErr != nil { + _ = srv.Close() + return "", fmt.Errorf("initiating social sign-in: %w", postErr) + } + fmt.Fprintf(ios.Out, "Opening browser to authenticate...\n") if err := browserOpener(authURL); err != nil { @@ -113,6 +127,41 @@ func BrowserLogin(host string, ios *iostreams.IOStreams) (string, error) { return token, nil } +// fetchSocialSignInURL POSTs to Better Auth's /sign-in/social endpoint and +// returns the OAuth provider redirect URL from the JSON response. +func fetchSocialSignInURL(baseURL, provider, callbackURL string) (string, error) { + body, err := json.Marshal(map[string]string{ + "provider": provider, + "callbackURL": callbackURL, + }) + if err != nil { + return "", err + } + + resp, err := http.Post(baseURL+"/api/auth/sign-in/social", "application/json", bytes.NewReader(body)) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("server returned %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + URL string `json:"url"` + Redirect bool `json:"redirect"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("decoding response: %w", err) + } + if result.URL == "" { + return "", errors.New("server returned empty redirect URL") + } + return result.URL, nil +} + // ReadTokenFromStdin reads a token from ios.In, trims whitespace, and returns it. // Returns an error if the input is empty after trimming. func ReadTokenFromStdin(ios *iostreams.IOStreams) (string, error) { diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go index 2b08336..9c13171 100644 --- a/internal/auth/oauth_test.go +++ b/internal/auth/oauth_test.go @@ -21,6 +21,15 @@ func installBrowserCapture(t *testing.T) <-chan string { return nil } t.Cleanup(func() { browserOpener = openBrowser }) + + // Override socialSignInURLFunc to return callbackURL as the "redirect URL" + // so the browser opener receives a URL containing the callback port. + origSocial := socialSignInURLFunc + socialSignInURLFunc = func(baseURL, provider, callbackURL string) (string, error) { + return baseURL + "/api/auth/sign-in/social?provider=" + provider + "&callbackURL=" + callbackURL, nil + } + t.Cleanup(func() { socialSignInURLFunc = origSocial }) + return ch } From 5635884c62350c0cae73724767b099aebb549ad5 Mon Sep 17 00:00:00 2001 From: Simon KP Date: Fri, 13 Mar 2026 22:25:14 +1100 Subject: [PATCH 2/5] feat: browser-based OAuth via server-side relay with nonce Replace direct POST to Better Auth with a server-side flow: CLI opens /cli-auth page in the browser, which POSTs from the same origin (avoiding CORS/CF Access issues). A relay endpoint reads the HttpOnly session cookie and redirects to the CLI callback with the token. A cryptographic nonce is threaded through the entire flow to prevent relay URL forgery. --- internal/auth/oauth.go | 86 +++++++++++++------------------- internal/auth/oauth_test.go | 99 +++++++++++++++++++++++++++---------- 2 files changed, 106 insertions(+), 79 deletions(-) diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index 4c11d29..ab96b77 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -1,9 +1,9 @@ package auth import ( - "bytes" "context" - "encoding/json" + "crypto/rand" + "encoding/hex" "errors" "fmt" "io" @@ -21,9 +21,6 @@ import ( // browserOpener opens a URL in the default browser. Tests override this. var browserOpener = openBrowser -// socialSignInURLFunc fetches the OAuth redirect URL from the server. Tests override this. -var socialSignInURLFunc = fetchSocialSignInURL - func openBrowser(url string) error { var cmd string var args []string @@ -41,9 +38,21 @@ func openBrowser(url string) error { return exec.Command(cmd, args...).Start() } +func generateNonce() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + // BrowserLogin starts a localhost OAuth callback server, opens the browser to -// authenticate via GitHub OAuth, captures the session token from the callback, -// stores it in the keyring, and returns the token. +// the server's /cli-auth page which initiates GitHub OAuth from the same origin, +// captures the session token from the callback, stores it in the keyring, +// and returns the token. +// +// A one-time nonce is generated and threaded through the entire flow to prevent +// an attacker from constructing a valid relay URL without knowledge of the nonce. func BrowserLogin(host string, ios *iostreams.IOStreams) (string, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -51,6 +60,12 @@ func BrowserLogin(host string, ios *iostreams.IOStreams) (string, error) { } port := listener.Addr().(*net.TCPAddr).Port + nonce, err := generateNonce() + if err != nil { + _ = listener.Close() + return "", fmt.Errorf("generating nonce: %w", err) + } + tokenCh := make(chan string, 1) errCh := make(chan error, 1) @@ -62,9 +77,14 @@ func BrowserLogin(host string, ios *iostreams.IOStreams) (string, error) { } mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("nonce") != nonce { + http.Error(w, "Invalid nonce", http.StatusForbidden) + errCh <- errors.New("callback nonce mismatch") + return + } + token := r.URL.Query().Get("token") if token == "" { - // Try cookie fallback if cookie, cookieErr := r.Cookie("better-auth.session_token"); cookieErr == nil { token = cookie.Value } @@ -85,19 +105,16 @@ func BrowserLogin(host string, ios *iostreams.IOStreams) (string, error) { } }() - // Better Auth's /sign-in/social is a POST endpoint that returns the OAuth redirect URL. - callbackURL := fmt.Sprintf("http://127.0.0.1:%d/callback", port) + // Open the server-side /cli-auth page. It POSTs to /api/auth/sign-in/social + // from the same origin (no CORS), using /api/cli-auth/relay as the OAuth + // callbackURL so the server can read the HttpOnly session cookie and forward + // the token to our local callback server. baseURL := khhttp.BuildBaseURL(host) - - authURL, postErr := socialSignInURLFunc(baseURL, "github", callbackURL) - if postErr != nil { - _ = srv.Close() - return "", fmt.Errorf("initiating social sign-in: %w", postErr) - } + authPageURL := fmt.Sprintf("%s/cli-auth?provider=github&port=%d&nonce=%s", baseURL, port, nonce) fmt.Fprintf(ios.Out, "Opening browser to authenticate...\n") - if err := browserOpener(authURL); err != nil { + if err := browserOpener(authPageURL); err != nil { _ = srv.Close() return "", fmt.Errorf("opening browser: %w", err) } @@ -127,41 +144,6 @@ func BrowserLogin(host string, ios *iostreams.IOStreams) (string, error) { return token, nil } -// fetchSocialSignInURL POSTs to Better Auth's /sign-in/social endpoint and -// returns the OAuth provider redirect URL from the JSON response. -func fetchSocialSignInURL(baseURL, provider, callbackURL string) (string, error) { - body, err := json.Marshal(map[string]string{ - "provider": provider, - "callbackURL": callbackURL, - }) - if err != nil { - return "", err - } - - resp, err := http.Post(baseURL+"/api/auth/sign-in/social", "application/json", bytes.NewReader(body)) - if err != nil { - return "", err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("server returned %d: %s", resp.StatusCode, string(respBody)) - } - - var result struct { - URL string `json:"url"` - Redirect bool `json:"redirect"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", fmt.Errorf("decoding response: %w", err) - } - if result.URL == "" { - return "", errors.New("server returned empty redirect URL") - } - return result.URL, nil -} - // ReadTokenFromStdin reads a token from ios.In, trims whitespace, and returns it. // Returns an error if the input is empty after trimming. func ReadTokenFromStdin(ios *iostreams.IOStreams) (string, error) { diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go index 9c13171..c5bb7cd 100644 --- a/internal/auth/oauth_test.go +++ b/internal/auth/oauth_test.go @@ -3,6 +3,7 @@ package auth import ( "fmt" "net/http" + "regexp" "strings" "testing" "time" @@ -11,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -// installBrowserCapture overrides browserOpener so it sends the auth URL on a channel. +// installBrowserCapture overrides browserOpener so it sends the URL on a channel. // Must be called BEFORE the BrowserLogin goroutine is started. func installBrowserCapture(t *testing.T) <-chan string { t.Helper() @@ -21,33 +22,31 @@ func installBrowserCapture(t *testing.T) <-chan string { return nil } t.Cleanup(func() { browserOpener = openBrowser }) + return ch +} - // Override socialSignInURLFunc to return callbackURL as the "redirect URL" - // so the browser opener receives a URL containing the callback port. - origSocial := socialSignInURLFunc - socialSignInURLFunc = func(baseURL, provider, callbackURL string) (string, error) { - return baseURL + "/api/auth/sign-in/social?provider=" + provider + "&callbackURL=" + callbackURL, nil - } - t.Cleanup(func() { socialSignInURLFunc = origSocial }) +var portPattern = regexp.MustCompile(`port=(\d+)`) +var noncePattern = regexp.MustCompile(`nonce=([a-f0-9]{32})`) - return ch +func extractPort(cliAuthURL string) string { + m := portPattern.FindStringSubmatch(cliAuthURL) + if len(m) < 2 { + return "" + } + return m[1] } -func extractCallbackPort(authURL string) string { - // authURL looks like: https://host/api/auth/sign-in/social?provider=github&callbackURL=http://127.0.0.1:PORT/callback - idx := strings.Index(authURL, "callbackURL=") - if idx < 0 { +func extractNonce(cliAuthURL string) string { + m := noncePattern.FindStringSubmatch(cliAuthURL) + if len(m) < 2 { return "" } - cb := authURL[idx+len("callbackURL="):] - portIdx := strings.LastIndex(cb, ":") - slashIdx := strings.Index(cb[portIdx:], "/") - return cb[portIdx+1 : portIdx+slashIdx] + return m[1] } func TestBrowserLogin_CapturesToken(t *testing.T) { overrideKeyring(t) - urlCh := installBrowserCapture(t) // set before goroutine starts + urlCh := installBrowserCapture(t) ios, _, _, _ := iostreams.Test() @@ -63,18 +62,24 @@ func TestBrowserLogin_CapturesToken(t *testing.T) { tokenCh <- tok }() - // Wait for browser opener to be called. - var authURL string + var cliAuthURL string select { - case authURL = <-urlCh: + case cliAuthURL = <-urlCh: case <-time.After(5 * time.Second): t.Fatal("browser opener was never called") } - port := extractCallbackPort(authURL) + require.Contains(t, cliAuthURL, "/cli-auth") + require.Contains(t, cliAuthURL, "provider=github") + + port := extractPort(cliAuthURL) require.NotEmpty(t, port) - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback?token=test_token_123", port)) + nonce := extractNonce(cliAuthURL) + require.NotEmpty(t, nonce, "nonce must be present in auth URL") + + // Simulate the relay redirecting to the CLI's local callback server. + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback?token=test_token_123&nonce=%s", port, nonce)) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) resp.Body.Close() @@ -93,6 +98,43 @@ func TestBrowserLogin_CapturesToken(t *testing.T) { require.Equal(t, "test_token_123", stored) } +func TestBrowserLogin_BadNonce(t *testing.T) { + overrideKeyring(t) + urlCh := installBrowserCapture(t) + + ios, _, _, _ := iostreams.Test() + + errCh := make(chan error, 1) + go func() { + _, err := BrowserLogin("app.keeperhub.com", ios) + errCh <- err + }() + + var cliAuthURL string + select { + case cliAuthURL = <-urlCh: + case <-time.After(5 * time.Second): + t.Fatal("browser opener was never called") + } + + port := extractPort(cliAuthURL) + require.NotEmpty(t, port) + + // Send a request with a wrong nonce. + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback?token=stolen&nonce=bad", port)) + require.NoError(t, err) + require.Equal(t, http.StatusForbidden, resp.StatusCode) + resp.Body.Close() + + select { + case err := <-errCh: + require.Error(t, err) + require.Contains(t, err.Error(), "nonce") + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for BrowserLogin to return error") + } +} + func TestBrowserLogin_NoTokenInCallback(t *testing.T) { overrideKeyring(t) urlCh := installBrowserCapture(t) @@ -105,17 +147,20 @@ func TestBrowserLogin_NoTokenInCallback(t *testing.T) { errCh <- err }() - var authURL string + var cliAuthURL string select { - case authURL = <-urlCh: + case cliAuthURL = <-urlCh: case <-time.After(5 * time.Second): t.Fatal("browser opener was never called") } - port := extractCallbackPort(authURL) + port := extractPort(cliAuthURL) + nonce := extractNonce(cliAuthURL) require.NotEmpty(t, port) + require.NotEmpty(t, nonce) - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback", port)) + // Valid nonce but no token. + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback?nonce=%s", port, nonce)) require.NoError(t, err) require.Equal(t, http.StatusBadRequest, resp.StatusCode) resp.Body.Close() From eb0b8ffa40a79c45042f468f041c9a163f9be0b8 Mon Sep 17 00:00:00 2001 From: Simon KP Date: Fri, 13 Mar 2026 22:36:56 +1100 Subject: [PATCH 3/5] feat: make device flow the default auth method Switch kh auth login to use device code flow (RFC 8628) by default, matching the pattern used by GitHub CLI. Browser OAuth is available via --browser flag. Device flow avoids CORS, cookie, and Cloudflare Access issues since the browser handles all server communication. Also adds CF-Access header support to device code/token requests for staging environments behind Cloudflare Access. --- cmd/auth/login.go | 21 +++++++++++---------- cmd/auth/login_test.go | 36 +++++++++++++++++++----------------- internal/auth/device.go | 27 +++++++++++++++++++++------ 3 files changed, 51 insertions(+), 33 deletions(-) diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 73f9e5b..0358e77 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -41,16 +41,17 @@ func NewLoginCmd(f *cmdutil.Factory) *cobra.Command { Use: "login", Short: "Log in to KeeperHub", Args: cobra.NoArgs, - Long: `Authenticate with KeeperHub. By default opens a browser for OAuth. -Use --no-browser for device code flow on headless or SSH environments. + Long: `Authenticate with KeeperHub using the device code flow. +Opens a browser to confirm a one-time code. +Use --browser for direct browser-based OAuth (useful if device flow is unavailable). Use --with-token to read an API key from stdin for non-interactive automation. See also: kh auth status, kh auth logout`, - Example: ` # Log in via browser + Example: ` # Log in (device code flow) kh auth login - # Log in on a headless machine - kh auth login --no-browser`, + # Log in via direct browser OAuth + kh auth login --browser`, RunE: func(cmd *cobra.Command, args []string) error { hosts, err := config.ReadHosts() if err != nil { @@ -66,7 +67,7 @@ See also: kh auth status, kh auth logout`, envHost := os.Getenv("KH_HOST") host := hosts.ActiveHost(flagHost, envHost) - noBrowser, _ := cmd.Flags().GetBool("no-browser") + useBrowser, _ := cmd.Flags().GetBool("browser") withToken, _ := cmd.Flags().GetBool("with-token") var token string @@ -82,15 +83,15 @@ See also: kh auth status, kh auth logout`, } token = t - case noBrowser: - t, loginErr := DeviceLoginFunc(host, f.IOStreams) + case useBrowser: + t, loginErr := BrowserLoginFunc(host, f.IOStreams) if loginErr != nil { return loginErr } token = t default: - t, loginErr := BrowserLoginFunc(host, f.IOStreams) + t, loginErr := DeviceLoginFunc(host, f.IOStreams) if loginErr != nil { return loginErr } @@ -113,7 +114,7 @@ See also: kh auth status, kh auth logout`, }, } - cmd.Flags().Bool("no-browser", false, "Do not open a browser window") + cmd.Flags().Bool("browser", false, "Use direct browser OAuth instead of device code flow") cmd.Flags().Bool("with-token", false, "Read token from stdin") return cmd diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 8e44c14..2fdf6a3 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -10,20 +10,20 @@ import ( "github.com/keeperhub/cli/pkg/iostreams" ) -func TestLoginCmd_BrowserFlow(t *testing.T) { +func TestLoginCmd_DefaultDeviceFlow(t *testing.T) { // Isolate from real config.yml so ActiveHost returns the hardcoded default. t.Setenv("XDG_CONFIG_HOME", t.TempDir()) ios, buf, _, _ := iostreams.Test() - browserCalled := false + deviceCalled := false auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - browserCalled = true - return "test-token", nil + t.Fatal("BrowserLogin should not be called in default flow") + return "", nil } auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - t.Fatal("DeviceLogin should not be called in browser flow") - return "", nil + deviceCalled = true + return "test-token", nil } auth.SetTokenFunc = func(host, token string) error { return nil } auth.FetchTokenInfoFunc = func(host, token string) (internalauth.TokenInfo, error) { @@ -38,8 +38,8 @@ func TestLoginCmd_BrowserFlow(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !browserCalled { - t.Error("expected BrowserLogin to be called") + if !deviceCalled { + t.Error("expected DeviceLogin to be called") } out := buf.String() if !strings.Contains(out, "app.keeperhub.com") { @@ -47,17 +47,19 @@ func TestLoginCmd_BrowserFlow(t *testing.T) { } } -func TestLoginCmd_NoBrowserFlag(t *testing.T) { +func TestLoginCmd_BrowserFlag(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + ios, _, _, _ := iostreams.Test() - deviceCalled := false + browserCalled := false auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - t.Fatal("BrowserLogin should not be called with --no-browser") - return "", nil + browserCalled = true + return "test-token", nil } auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - deviceCalled = true - return "test-token", nil + t.Fatal("DeviceLogin should not be called with --browser") + return "", nil } auth.SetTokenFunc = func(host, token string) error { return nil } auth.FetchTokenInfoFunc = func(host, token string) (internalauth.TokenInfo, error) { @@ -66,14 +68,14 @@ func TestLoginCmd_NoBrowserFlag(t *testing.T) { f := &cmdutil.Factory{IOStreams: ios} cmd := auth.NewLoginCmd(f) - cmd.SetArgs([]string{"--no-browser"}) + cmd.SetArgs([]string{"--browser"}) err := cmd.Execute() if err != nil { t.Fatalf("unexpected error: %v", err) } - if !deviceCalled { - t.Error("expected DeviceLogin to be called") + if !browserCalled { + t.Error("expected BrowserLogin to be called") } } diff --git a/internal/auth/device.go b/internal/auth/device.go index 03bb1ab..907bcea 100644 --- a/internal/auth/device.go +++ b/internal/auth/device.go @@ -9,6 +9,7 @@ import ( "net/http" "time" + "github.com/keeperhub/cli/internal/config" khhttp "github.com/keeperhub/cli/internal/http" "github.com/keeperhub/cli/pkg/iostreams" ) @@ -33,7 +34,15 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) defer cancel() - codeResp, err := requestDeviceCode(ctx, host) + // Load per-host headers (e.g. CF-Access) from hosts.yml. + var hostHeaders map[string]string + if hostsCfg, hostsErr := config.ReadHosts(); hostsErr == nil { + if entry, ok := hostsCfg.HostEntry(host); ok { + hostHeaders = entry.Headers + } + } + + codeResp, err := requestDeviceCode(ctx, host, hostHeaders) if err != nil { return "", err } @@ -46,7 +55,7 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { interval = 5 * time.Second } - token, err := pollDeviceToken(ctx, host, codeResp.DeviceCode, interval) + token, err := pollDeviceToken(ctx, host, codeResp.DeviceCode, interval, hostHeaders) if err != nil { return "", err } @@ -58,7 +67,7 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { return token, nil } -func requestDeviceCode(ctx context.Context, host string) (deviceCodeResponse, error) { +func requestDeviceCode(ctx context.Context, host string, headers map[string]string) (deviceCodeResponse, error) { body, err := json.Marshal(map[string]string{"client_id": "kh-cli"}) if err != nil { return deviceCodeResponse{}, err @@ -70,6 +79,9 @@ func requestDeviceCode(ctx context.Context, host string) (deviceCodeResponse, er return deviceCodeResponse{}, err } req.Header.Set("Content-Type", "application/json") + for k, v := range headers { + req.Header.Set(k, v) + } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -89,7 +101,7 @@ func requestDeviceCode(ctx context.Context, host string) (deviceCodeResponse, er return codeResp, nil } -func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time.Duration) (string, error) { +func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time.Duration, headers map[string]string) (string, error) { for { select { case <-ctx.Done(): @@ -97,7 +109,7 @@ func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time case <-time.After(interval): } - tokenResp, err := checkDeviceToken(ctx, host, deviceCode) + tokenResp, err := checkDeviceToken(ctx, host, deviceCode, headers) if err != nil { return "", err } @@ -121,7 +133,7 @@ func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time } } -func checkDeviceToken(ctx context.Context, host, deviceCode string) (deviceTokenResponse, error) { +func checkDeviceToken(ctx context.Context, host, deviceCode string, headers map[string]string) (deviceTokenResponse, error) { body, err := json.Marshal(map[string]string{ "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": deviceCode, @@ -137,6 +149,9 @@ func checkDeviceToken(ctx context.Context, host, deviceCode string) (deviceToken return deviceTokenResponse{}, err } req.Header.Set("Content-Type", "application/json") + for k, v := range headers { + req.Header.Set(k, v) + } resp, err := http.DefaultClient.Do(req) if err != nil { From 9ff8747eea2992dee541ddd618989cb9c9f4a429 Mon Sep 17 00:00:00 2001 From: Simon KP Date: Fri, 13 Mar 2026 22:39:29 +1100 Subject: [PATCH 4/5] fix: add Origin header to device flow requests for Better Auth CSRF --- internal/auth/device.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/internal/auth/device.go b/internal/auth/device.go index 907bcea..0bf30eb 100644 --- a/internal/auth/device.go +++ b/internal/auth/device.go @@ -42,7 +42,8 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { } } - codeResp, err := requestDeviceCode(ctx, host, hostHeaders) + baseURL := khhttp.BuildBaseURL(host) + codeResp, err := requestDeviceCode(ctx, baseURL, hostHeaders) if err != nil { return "", err } @@ -55,7 +56,7 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { interval = 5 * time.Second } - token, err := pollDeviceToken(ctx, host, codeResp.DeviceCode, interval, hostHeaders) + token, err := pollDeviceToken(ctx, baseURL, codeResp.DeviceCode, interval, hostHeaders) if err != nil { return "", err } @@ -67,18 +68,19 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { return token, nil } -func requestDeviceCode(ctx context.Context, host string, headers map[string]string) (deviceCodeResponse, error) { +func requestDeviceCode(ctx context.Context, baseURL string, headers map[string]string) (deviceCodeResponse, error) { body, err := json.Marshal(map[string]string{"client_id": "kh-cli"}) if err != nil { return deviceCodeResponse{}, err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, - khhttp.BuildBaseURL(host)+"/api/auth/device/code", bytes.NewReader(body)) + baseURL+"/api/auth/device/code", bytes.NewReader(body)) if err != nil { return deviceCodeResponse{}, err } req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", baseURL) for k, v := range headers { req.Header.Set(k, v) } @@ -101,7 +103,7 @@ func requestDeviceCode(ctx context.Context, host string, headers map[string]stri return codeResp, nil } -func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time.Duration, headers map[string]string) (string, error) { +func pollDeviceToken(ctx context.Context, baseURL, deviceCode string, interval time.Duration, headers map[string]string) (string, error) { for { select { case <-ctx.Done(): @@ -109,7 +111,7 @@ func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time case <-time.After(interval): } - tokenResp, err := checkDeviceToken(ctx, host, deviceCode, headers) + tokenResp, err := checkDeviceToken(ctx, baseURL, deviceCode, headers) if err != nil { return "", err } @@ -133,7 +135,7 @@ func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time } } -func checkDeviceToken(ctx context.Context, host, deviceCode string, headers map[string]string) (deviceTokenResponse, error) { +func checkDeviceToken(ctx context.Context, baseURL, deviceCode string, headers map[string]string) (deviceTokenResponse, error) { body, err := json.Marshal(map[string]string{ "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": deviceCode, @@ -144,11 +146,12 @@ func checkDeviceToken(ctx context.Context, host, deviceCode string, headers map[ } req, err := http.NewRequestWithContext(ctx, http.MethodPost, - khhttp.BuildBaseURL(host)+"/api/auth/device/token", bytes.NewReader(body)) + baseURL+"/api/auth/device/token", bytes.NewReader(body)) if err != nil { return deviceTokenResponse{}, err } req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", baseURL) for k, v := range headers { req.Header.Set(k, v) } From 11343bf287d96576407a1de5c83b52aa81441a1b Mon Sep 17 00:00:00 2001 From: Simon KP Date: Fri, 13 Mar 2026 23:20:59 +1100 Subject: [PATCH 5/5] refactor: remove dead browser OAuth flow - Remove --browser flag (server-side relay page was never built) - Delete oauth.go and oauth_test.go - Move ReadTokenFromStdin to device.go - Device flow is the only interactive auth method --- cmd/auth/login.go | 27 +---- cmd/auth/login_test.go | 42 -------- internal/auth/device.go | 15 +++ internal/auth/oauth.go | 159 ------------------------------ internal/auth/oauth_test.go | 191 ------------------------------------ 5 files changed, 19 insertions(+), 415 deletions(-) delete mode 100644 internal/auth/oauth.go delete mode 100644 internal/auth/oauth_test.go diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 0358e77..a7a9420 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -12,12 +12,6 @@ import ( "github.com/spf13/cobra" ) -// BrowserLoginFunc is the function used to perform browser-based login. -// Tests may override this to avoid opening a real browser. -var BrowserLoginFunc = func(host string, ios *iostreams.IOStreams) (string, error) { - return internalauth.BrowserLogin(host, ios) -} - // DeviceLoginFunc is the function used to perform device-code login. // Tests may override this to avoid real network calls. var DeviceLoginFunc = func(host string, ios *iostreams.IOStreams) (string, error) { @@ -43,15 +37,14 @@ func NewLoginCmd(f *cmdutil.Factory) *cobra.Command { Args: cobra.NoArgs, Long: `Authenticate with KeeperHub using the device code flow. Opens a browser to confirm a one-time code. -Use --browser for direct browser-based OAuth (useful if device flow is unavailable). Use --with-token to read an API key from stdin for non-interactive automation. See also: kh auth status, kh auth logout`, Example: ` # Log in (device code flow) kh auth login - # Log in via direct browser OAuth - kh auth login --browser`, + # Log in with an API key (non-interactive) + echo "kh_xxx" | kh auth login --with-token`, RunE: func(cmd *cobra.Command, args []string) error { hosts, err := config.ReadHosts() if err != nil { @@ -67,13 +60,11 @@ See also: kh auth status, kh auth logout`, envHost := os.Getenv("KH_HOST") host := hosts.ActiveHost(flagHost, envHost) - useBrowser, _ := cmd.Flags().GetBool("browser") withToken, _ := cmd.Flags().GetBool("with-token") var token string - switch { - case withToken: + if withToken { t, readErr := internalauth.ReadTokenFromStdin(f.IOStreams) if readErr != nil { return readErr @@ -82,15 +73,7 @@ See also: kh auth status, kh auth logout`, return fmt.Errorf("storing token: %w", err) } token = t - - case useBrowser: - t, loginErr := BrowserLoginFunc(host, f.IOStreams) - if loginErr != nil { - return loginErr - } - token = t - - default: + } else { t, loginErr := DeviceLoginFunc(host, f.IOStreams) if loginErr != nil { return loginErr @@ -104,7 +87,6 @@ See also: kh auth status, kh auth logout`, info, err := FetchTokenInfoFunc(host, token) if err != nil { - // Non-fatal: login succeeded but we can't fetch user details fmt.Fprintf(f.IOStreams.Out, "Logged in to %s\n", host) return nil } @@ -114,7 +96,6 @@ See also: kh auth status, kh auth logout`, }, } - cmd.Flags().Bool("browser", false, "Use direct browser OAuth instead of device code flow") cmd.Flags().Bool("with-token", false, "Read token from stdin") return cmd diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 2fdf6a3..97d59ca 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -11,16 +11,11 @@ import ( ) func TestLoginCmd_DefaultDeviceFlow(t *testing.T) { - // Isolate from real config.yml so ActiveHost returns the hardcoded default. t.Setenv("XDG_CONFIG_HOME", t.TempDir()) ios, buf, _, _ := iostreams.Test() deviceCalled := false - auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - t.Fatal("BrowserLogin should not be called in default flow") - return "", nil - } auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { deviceCalled = true return "test-token", nil @@ -47,48 +42,12 @@ func TestLoginCmd_DefaultDeviceFlow(t *testing.T) { } } -func TestLoginCmd_BrowserFlag(t *testing.T) { - t.Setenv("XDG_CONFIG_HOME", t.TempDir()) - - ios, _, _, _ := iostreams.Test() - - browserCalled := false - auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - browserCalled = true - return "test-token", nil - } - auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - t.Fatal("DeviceLogin should not be called with --browser") - return "", nil - } - auth.SetTokenFunc = func(host, token string) error { return nil } - auth.FetchTokenInfoFunc = func(host, token string) (internalauth.TokenInfo, error) { - return internalauth.TokenInfo{Email: "user@example.com"}, nil - } - - f := &cmdutil.Factory{IOStreams: ios} - cmd := auth.NewLoginCmd(f) - cmd.SetArgs([]string{"--browser"}) - - err := cmd.Execute() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !browserCalled { - t.Error("expected BrowserLogin to be called") - } -} - func TestLoginCmd_WithTokenFlag(t *testing.T) { ios, buf, _, _ := iostreams.Test() ios.In = strings.NewReader("my-token-from-stdin\n") storeHost := "" storeToken := "" - auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { - t.Fatal("BrowserLogin should not be called with --with-token") - return "", nil - } auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { t.Fatal("DeviceLogin should not be called with --with-token") return "", nil @@ -126,7 +85,6 @@ func TestLoginCmd_WithTokenFlag_EmptyStdin(t *testing.T) { ios, _, _, _ := iostreams.Test() ios.In = strings.NewReader("") - auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { return "", nil } auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { return "", nil } auth.SetTokenFunc = func(host, token string) error { return nil } auth.FetchTokenInfoFunc = func(host, token string) (internalauth.TokenInfo, error) { diff --git a/internal/auth/device.go b/internal/auth/device.go index 0bf30eb..5d04fca 100644 --- a/internal/auth/device.go +++ b/internal/auth/device.go @@ -6,7 +6,9 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" + "strings" "time" "github.com/keeperhub/cli/internal/config" @@ -169,3 +171,16 @@ func checkDeviceToken(ctx context.Context, baseURL, deviceCode string, headers m return tokenResp, nil } + +// ReadTokenFromStdin reads a token from ios.In, trims whitespace, and returns it. +func ReadTokenFromStdin(ios *iostreams.IOStreams) (string, error) { + data, err := io.ReadAll(ios.In) + if err != nil { + return "", fmt.Errorf("reading token from stdin: %w", err) + } + token := strings.TrimSpace(string(data)) + if token == "" { + return "", errors.New("no token provided on stdin") + } + return token, nil +} diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go deleted file mode 100644 index ab96b77..0000000 --- a/internal/auth/oauth.go +++ /dev/null @@ -1,159 +0,0 @@ -package auth - -import ( - "context" - "crypto/rand" - "encoding/hex" - "errors" - "fmt" - "io" - "net" - "net/http" - "os/exec" - "runtime" - "strings" - "time" - - khhttp "github.com/keeperhub/cli/internal/http" - "github.com/keeperhub/cli/pkg/iostreams" -) - -// browserOpener opens a URL in the default browser. Tests override this. -var browserOpener = openBrowser - -func openBrowser(url string) error { - var cmd string - var args []string - switch runtime.GOOS { - case "darwin": - cmd = "open" - args = []string{url} - case "windows": - cmd = "rundll32" - args = []string{"url.dll,FileProtocolHandler", url} - default: - cmd = "xdg-open" - args = []string{url} - } - return exec.Command(cmd, args...).Start() -} - -func generateNonce() (string, error) { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - return "", err - } - return hex.EncodeToString(b), nil -} - -// BrowserLogin starts a localhost OAuth callback server, opens the browser to -// the server's /cli-auth page which initiates GitHub OAuth from the same origin, -// captures the session token from the callback, stores it in the keyring, -// and returns the token. -// -// A one-time nonce is generated and threaded through the entire flow to prevent -// an attacker from constructing a valid relay URL without knowledge of the nonce. -func BrowserLogin(host string, ios *iostreams.IOStreams) (string, error) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return "", fmt.Errorf("starting callback server: %w", err) - } - port := listener.Addr().(*net.TCPAddr).Port - - nonce, err := generateNonce() - if err != nil { - _ = listener.Close() - return "", fmt.Errorf("generating nonce: %w", err) - } - - tokenCh := make(chan string, 1) - errCh := make(chan error, 1) - - mux := http.NewServeMux() - srv := &http.Server{ - Handler: mux, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - } - - mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { - if r.URL.Query().Get("nonce") != nonce { - http.Error(w, "Invalid nonce", http.StatusForbidden) - errCh <- errors.New("callback nonce mismatch") - return - } - - token := r.URL.Query().Get("token") - if token == "" { - if cookie, cookieErr := r.Cookie("better-auth.session_token"); cookieErr == nil { - token = cookie.Value - } - } - if token == "" { - http.Error(w, "No token received", http.StatusBadRequest) - errCh <- errors.New("no token in callback") - return - } - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, "

Authentication successful!

You can close this tab.

") - tokenCh <- token - }) - - go func() { - if serveErr := srv.Serve(listener); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { - errCh <- serveErr - } - }() - - // Open the server-side /cli-auth page. It POSTs to /api/auth/sign-in/social - // from the same origin (no CORS), using /api/cli-auth/relay as the OAuth - // callbackURL so the server can read the HttpOnly session cookie and forward - // the token to our local callback server. - baseURL := khhttp.BuildBaseURL(host) - authPageURL := fmt.Sprintf("%s/cli-auth?provider=github&port=%d&nonce=%s", baseURL, port, nonce) - - fmt.Fprintf(ios.Out, "Opening browser to authenticate...\n") - - if err := browserOpener(authPageURL); err != nil { - _ = srv.Close() - return "", fmt.Errorf("opening browser: %w", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - var token string - select { - case token = <-tokenCh: - case err = <-errCh: - _ = srv.Close() - return "", err - case <-ctx.Done(): - _ = srv.Close() - return "", errors.New("timed out waiting for browser authentication") - } - - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 3*time.Second) - defer shutdownCancel() - _ = srv.Shutdown(shutdownCtx) - - if storeErr := SetToken(host, token); storeErr != nil { - return "", fmt.Errorf("storing token: %w", storeErr) - } - - return token, nil -} - -// ReadTokenFromStdin reads a token from ios.In, trims whitespace, and returns it. -// Returns an error if the input is empty after trimming. -func ReadTokenFromStdin(ios *iostreams.IOStreams) (string, error) { - data, err := io.ReadAll(ios.In) - if err != nil { - return "", fmt.Errorf("reading token from stdin: %w", err) - } - token := strings.TrimSpace(string(data)) - if token == "" { - return "", errors.New("no token provided on stdin") - } - return token, nil -} diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go deleted file mode 100644 index c5bb7cd..0000000 --- a/internal/auth/oauth_test.go +++ /dev/null @@ -1,191 +0,0 @@ -package auth - -import ( - "fmt" - "net/http" - "regexp" - "strings" - "testing" - "time" - - "github.com/keeperhub/cli/pkg/iostreams" - "github.com/stretchr/testify/require" -) - -// installBrowserCapture overrides browserOpener so it sends the URL on a channel. -// Must be called BEFORE the BrowserLogin goroutine is started. -func installBrowserCapture(t *testing.T) <-chan string { - t.Helper() - ch := make(chan string, 1) - browserOpener = func(url string) error { - ch <- url - return nil - } - t.Cleanup(func() { browserOpener = openBrowser }) - return ch -} - -var portPattern = regexp.MustCompile(`port=(\d+)`) -var noncePattern = regexp.MustCompile(`nonce=([a-f0-9]{32})`) - -func extractPort(cliAuthURL string) string { - m := portPattern.FindStringSubmatch(cliAuthURL) - if len(m) < 2 { - return "" - } - return m[1] -} - -func extractNonce(cliAuthURL string) string { - m := noncePattern.FindStringSubmatch(cliAuthURL) - if len(m) < 2 { - return "" - } - return m[1] -} - -func TestBrowserLogin_CapturesToken(t *testing.T) { - overrideKeyring(t) - urlCh := installBrowserCapture(t) - - ios, _, _, _ := iostreams.Test() - - tokenCh := make(chan string, 1) - errCh := make(chan error, 1) - - go func() { - tok, err := BrowserLogin("app.keeperhub.com", ios) - if err != nil { - errCh <- err - return - } - tokenCh <- tok - }() - - var cliAuthURL string - select { - case cliAuthURL = <-urlCh: - case <-time.After(5 * time.Second): - t.Fatal("browser opener was never called") - } - - require.Contains(t, cliAuthURL, "/cli-auth") - require.Contains(t, cliAuthURL, "provider=github") - - port := extractPort(cliAuthURL) - require.NotEmpty(t, port) - - nonce := extractNonce(cliAuthURL) - require.NotEmpty(t, nonce, "nonce must be present in auth URL") - - // Simulate the relay redirecting to the CLI's local callback server. - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback?token=test_token_123&nonce=%s", port, nonce)) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - resp.Body.Close() - - select { - case tok := <-tokenCh: - require.Equal(t, "test_token_123", tok) - case err := <-errCh: - t.Fatalf("BrowserLogin returned error: %v", err) - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for BrowserLogin to return") - } - - stored, err := GetToken("app.keeperhub.com") - require.NoError(t, err) - require.Equal(t, "test_token_123", stored) -} - -func TestBrowserLogin_BadNonce(t *testing.T) { - overrideKeyring(t) - urlCh := installBrowserCapture(t) - - ios, _, _, _ := iostreams.Test() - - errCh := make(chan error, 1) - go func() { - _, err := BrowserLogin("app.keeperhub.com", ios) - errCh <- err - }() - - var cliAuthURL string - select { - case cliAuthURL = <-urlCh: - case <-time.After(5 * time.Second): - t.Fatal("browser opener was never called") - } - - port := extractPort(cliAuthURL) - require.NotEmpty(t, port) - - // Send a request with a wrong nonce. - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback?token=stolen&nonce=bad", port)) - require.NoError(t, err) - require.Equal(t, http.StatusForbidden, resp.StatusCode) - resp.Body.Close() - - select { - case err := <-errCh: - require.Error(t, err) - require.Contains(t, err.Error(), "nonce") - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for BrowserLogin to return error") - } -} - -func TestBrowserLogin_NoTokenInCallback(t *testing.T) { - overrideKeyring(t) - urlCh := installBrowserCapture(t) - - ios, _, _, _ := iostreams.Test() - - errCh := make(chan error, 1) - go func() { - _, err := BrowserLogin("app.keeperhub.com", ios) - errCh <- err - }() - - var cliAuthURL string - select { - case cliAuthURL = <-urlCh: - case <-time.After(5 * time.Second): - t.Fatal("browser opener was never called") - } - - port := extractPort(cliAuthURL) - nonce := extractNonce(cliAuthURL) - require.NotEmpty(t, port) - require.NotEmpty(t, nonce) - - // Valid nonce but no token. - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/callback?nonce=%s", port, nonce)) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, resp.StatusCode) - resp.Body.Close() - - select { - case err := <-errCh: - require.Error(t, err) - case <-time.After(5 * time.Second): - t.Fatal("timed out waiting for BrowserLogin to return error") - } -} - -func TestReadTokenFromStdin(t *testing.T) { - ios, _, _, _ := iostreams.Test() - ios.In = strings.NewReader(" my_token_value\n ") - - tok, err := ReadTokenFromStdin(ios) - require.NoError(t, err) - require.Equal(t, "my_token_value", tok) -} - -func TestReadTokenFromStdin_Empty(t *testing.T) { - ios, _, _, _ := iostreams.Test() - ios.In = strings.NewReader(" \n ") - - _, err := ReadTokenFromStdin(ios) - require.Error(t, err) -}