Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 7 additions & 25 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ import (
"github.com/spf13/cobra"
)

// BrowserLoginFunc is the function used to perform browser-based login.
// Tests may override this to avoid opening a real browser.
var BrowserLoginFunc = func(host string, ios *iostreams.IOStreams) (string, error) {
return internalauth.BrowserLogin(host, ios)
}

// DeviceLoginFunc is the function used to perform device-code login.
// Tests may override this to avoid real network calls.
var DeviceLoginFunc = func(host string, ios *iostreams.IOStreams) (string, error) {
Expand All @@ -41,16 +35,16 @@ func NewLoginCmd(f *cmdutil.Factory) *cobra.Command {
Use: "login",
Short: "Log in to KeeperHub",
Args: cobra.NoArgs,
Long: `Authenticate with KeeperHub. By default opens a browser for OAuth.
Use --no-browser for device code flow on headless or SSH environments.
Long: `Authenticate with KeeperHub using the device code flow.
Opens a browser to confirm a one-time code.
Use --with-token to read an API key from stdin for non-interactive automation.

See also: kh auth status, kh auth logout`,
Example: ` # Log in via browser
Example: ` # Log in (device code flow)
kh auth login

# Log in on a headless machine
kh auth login --no-browser`,
# Log in with an API key (non-interactive)
echo "kh_xxx" | kh auth login --with-token`,
RunE: func(cmd *cobra.Command, args []string) error {
hosts, err := config.ReadHosts()
if err != nil {
Expand All @@ -66,13 +60,11 @@ See also: kh auth status, kh auth logout`,
envHost := os.Getenv("KH_HOST")
host := hosts.ActiveHost(flagHost, envHost)

noBrowser, _ := cmd.Flags().GetBool("no-browser")
withToken, _ := cmd.Flags().GetBool("with-token")

var token string

switch {
case withToken:
if withToken {
t, readErr := internalauth.ReadTokenFromStdin(f.IOStreams)
if readErr != nil {
return readErr
Expand All @@ -81,20 +73,12 @@ See also: kh auth status, kh auth logout`,
return fmt.Errorf("storing token: %w", err)
}
token = t

case noBrowser:
} else {
t, loginErr := DeviceLoginFunc(host, f.IOStreams)
if loginErr != nil {
return loginErr
}
token = t

default:
t, loginErr := BrowserLoginFunc(host, f.IOStreams)
if loginErr != nil {
return loginErr
}
token = t
}

if token == "" {
Expand All @@ -103,7 +87,6 @@ See also: kh auth status, kh auth logout`,

info, err := FetchTokenInfoFunc(host, token)
if err != nil {
// Non-fatal: login succeeded but we can't fetch user details
fmt.Fprintf(f.IOStreams.Out, "Logged in to %s\n", host)
return nil
}
Expand All @@ -113,7 +96,6 @@ See also: kh auth status, kh auth logout`,
},
}

cmd.Flags().Bool("no-browser", false, "Do not open a browser window")
cmd.Flags().Bool("with-token", false, "Read token from stdin")

return cmd
Expand Down
52 changes: 6 additions & 46 deletions cmd/auth/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,51 +10,12 @@ import (
"github.com/keeperhub/cli/pkg/iostreams"
)

func TestLoginCmd_BrowserFlow(t *testing.T) {
// Isolate from real config.yml so ActiveHost returns the hardcoded default.
func TestLoginCmd_DefaultDeviceFlow(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", t.TempDir())

ios, buf, _, _ := iostreams.Test()

browserCalled := false
auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) {
browserCalled = true
return "test-token", nil
}
auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) {
t.Fatal("DeviceLogin should not be called in browser flow")
return "", nil
}
auth.SetTokenFunc = func(host, token string) error { return nil }
auth.FetchTokenInfoFunc = func(host, token string) (internalauth.TokenInfo, error) {
return internalauth.TokenInfo{Email: "user@example.com"}, nil
}

f := &cmdutil.Factory{IOStreams: ios}
cmd := auth.NewLoginCmd(f)
cmd.SetArgs([]string{})

err := cmd.Execute()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !browserCalled {
t.Error("expected BrowserLogin to be called")
}
out := buf.String()
if !strings.Contains(out, "app.keeperhub.com") {
t.Errorf("expected host in output, got: %q", out)
}
}

func TestLoginCmd_NoBrowserFlag(t *testing.T) {
ios, _, _, _ := iostreams.Test()

deviceCalled := false
auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) {
t.Fatal("BrowserLogin should not be called with --no-browser")
return "", nil
}
auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) {
deviceCalled = true
return "test-token", nil
Expand All @@ -66,7 +27,7 @@ func TestLoginCmd_NoBrowserFlag(t *testing.T) {

f := &cmdutil.Factory{IOStreams: ios}
cmd := auth.NewLoginCmd(f)
cmd.SetArgs([]string{"--no-browser"})
cmd.SetArgs([]string{})

err := cmd.Execute()
if err != nil {
Expand All @@ -75,6 +36,10 @@ func TestLoginCmd_NoBrowserFlag(t *testing.T) {
if !deviceCalled {
t.Error("expected DeviceLogin to be called")
}
out := buf.String()
if !strings.Contains(out, "app.keeperhub.com") {
t.Errorf("expected host in output, got: %q", out)
}
}

func TestLoginCmd_WithTokenFlag(t *testing.T) {
Expand All @@ -83,10 +48,6 @@ func TestLoginCmd_WithTokenFlag(t *testing.T) {

storeHost := ""
storeToken := ""
auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) {
t.Fatal("BrowserLogin should not be called with --with-token")
return "", nil
}
auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) {
t.Fatal("DeviceLogin should not be called with --with-token")
return "", nil
Expand Down Expand Up @@ -124,7 +85,6 @@ func TestLoginCmd_WithTokenFlag_EmptyStdin(t *testing.T) {
ios, _, _, _ := iostreams.Test()
ios.In = strings.NewReader("")

auth.BrowserLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { return "", nil }
auth.DeviceLoginFunc = func(host string, ios2 *iostreams.IOStreams) (string, error) { return "", nil }
auth.SetTokenFunc = func(host, token string) error { return nil }
auth.FetchTokenInfoFunc = func(host, token string) (internalauth.TokenInfo, error) {
Expand Down
49 changes: 41 additions & 8 deletions internal/auth/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/keeperhub/cli/internal/config"
khhttp "github.com/keeperhub/cli/internal/http"
"github.com/keeperhub/cli/pkg/iostreams"
)
Expand All @@ -33,7 +36,16 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()

codeResp, err := requestDeviceCode(ctx, host)
// Load per-host headers (e.g. CF-Access) from hosts.yml.
var hostHeaders map[string]string
if hostsCfg, hostsErr := config.ReadHosts(); hostsErr == nil {
if entry, ok := hostsCfg.HostEntry(host); ok {
hostHeaders = entry.Headers
}
}

baseURL := khhttp.BuildBaseURL(host)
codeResp, err := requestDeviceCode(ctx, baseURL, hostHeaders)
if err != nil {
return "", err
}
Expand All @@ -46,7 +58,7 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) {
interval = 5 * time.Second
}

token, err := pollDeviceToken(ctx, host, codeResp.DeviceCode, interval)
token, err := pollDeviceToken(ctx, baseURL, codeResp.DeviceCode, interval, hostHeaders)
if err != nil {
return "", err
}
Expand All @@ -58,18 +70,22 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) {
return token, nil
}

func requestDeviceCode(ctx context.Context, host string) (deviceCodeResponse, error) {
func requestDeviceCode(ctx context.Context, baseURL string, headers map[string]string) (deviceCodeResponse, error) {
body, err := json.Marshal(map[string]string{"client_id": "kh-cli"})
if err != nil {
return deviceCodeResponse{}, err
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost,
khhttp.BuildBaseURL(host)+"/api/auth/device/code", bytes.NewReader(body))
baseURL+"/api/auth/device/code", bytes.NewReader(body))
if err != nil {
return deviceCodeResponse{}, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Origin", baseURL)
for k, v := range headers {
req.Header.Set(k, v)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
Expand All @@ -89,15 +105,15 @@ func requestDeviceCode(ctx context.Context, host string) (deviceCodeResponse, er
return codeResp, nil
}

func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time.Duration) (string, error) {
func pollDeviceToken(ctx context.Context, baseURL, deviceCode string, interval time.Duration, headers map[string]string) (string, error) {
for {
select {
case <-ctx.Done():
return "", errors.New("timed out waiting for device authorisation")
case <-time.After(interval):
}

tokenResp, err := checkDeviceToken(ctx, host, deviceCode)
tokenResp, err := checkDeviceToken(ctx, baseURL, deviceCode, headers)
if err != nil {
return "", err
}
Expand All @@ -121,7 +137,7 @@ func pollDeviceToken(ctx context.Context, host, deviceCode string, interval time
}
}

func checkDeviceToken(ctx context.Context, host, deviceCode string) (deviceTokenResponse, error) {
func checkDeviceToken(ctx context.Context, baseURL, deviceCode string, headers map[string]string) (deviceTokenResponse, error) {
body, err := json.Marshal(map[string]string{
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": deviceCode,
Expand All @@ -132,11 +148,15 @@ func checkDeviceToken(ctx context.Context, host, deviceCode string) (deviceToken
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost,
khhttp.BuildBaseURL(host)+"/api/auth/device/token", bytes.NewReader(body))
baseURL+"/api/auth/device/token", bytes.NewReader(body))
if err != nil {
return deviceTokenResponse{}, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Origin", baseURL)
for k, v := range headers {
req.Header.Set(k, v)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
Expand All @@ -151,3 +171,16 @@ func checkDeviceToken(ctx context.Context, host, deviceCode string) (deviceToken

return tokenResp, nil
}

// ReadTokenFromStdin reads a token from ios.In, trims whitespace, and returns it.
func ReadTokenFromStdin(ios *iostreams.IOStreams) (string, error) {
data, err := io.ReadAll(ios.In)
if err != nil {
return "", fmt.Errorf("reading token from stdin: %w", err)
}
token := strings.TrimSpace(string(data))
if token == "" {
return "", errors.New("no token provided on stdin")
}
return token, nil
}
Loading
Loading