diff --git a/commands/login.go b/commands/login.go index 75813dd1..20ac87c2 100644 --- a/commands/login.go +++ b/commands/login.go @@ -26,6 +26,7 @@ import ( "fmt" "io" "log" + "net" "os" "path/filepath" "strings" @@ -41,6 +42,7 @@ import ( "github.com/openpubkey/openpubkey/util" "github.com/openpubkey/opkssh/sshcert" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" ) type LoginCmd struct { @@ -58,6 +60,30 @@ type LoginCmd struct { principals []string } +func HasSSHAgent() (string, bool) { + return os.LookupEnv("SSH_AUTH_SOCK") +} + +func getLifetime(b []byte) (uint32, error) { + var claims struct { + IssuedAt int64 `json:"iat"` + } + if err := json.Unmarshal(b, &claims); err != nil { + return 0, fmt.Errorf("malformed refreshed ID token payload: %w", err) + } + return uint32(time.Duration(time.Until(time.Unix(claims.IssuedAt, 0).Add(23 * time.Hour)).Seconds())), nil +} + +func getExpiration(b []byte) (time.Duration, error) { + var claims struct { + Expiration int64 `json:"exp"` + } + if err := json.Unmarshal(b, &claims); err != nil { + return 0, fmt.Errorf("malformed refreshed ID token payload: %w", err) + } + return time.Until(time.Unix(claims.Expiration, 0)) - time.Minute, nil +} + func NewLogin(autoRefresh bool, logDir string, disableBrowserOpenArg bool, printIdTokenArg bool, providerArg string, keyPathArg string, providerFromLdFlags providers.OpenIdProvider) *LoginCmd { return &LoginCmd{ autoRefresh: autoRefresh, @@ -180,10 +206,20 @@ func (l *LoginCmd) Run(ctx context.Context) error { } } + var sshagent agent.Agent + if s, ok := HasSSHAgent(); ok { + conn, err := net.Dial("unix", s) + if err != nil { + return fmt.Errorf("failed to connect to ssh-agent socket: %w", err) + } + defer conn.Close() + sshagent = agent.NewClient(conn) + } + // Execute login command if l.autoRefresh { if providerRefreshable, ok := provider.(providers.RefreshableOpenIdProvider); ok { - err := LoginWithRefresh(ctx, providerRefreshable, l.printIdTokenArg, l.keyPathArg) + err := LoginWithRefresh(ctx, providerRefreshable, l.printIdTokenArg, l.keyPathArg, sshagent) if err != nil { return fmt.Errorf("error logging in: %w", err) } @@ -191,7 +227,7 @@ func (l *LoginCmd) Run(ctx context.Context) error { return fmt.Errorf("supplied OpenID Provider (%v) does not support auto-refresh and auto-refresh argument set to true", provider.Issuer()) } } else { - err := Login(ctx, provider, l.printIdTokenArg, l.keyPathArg) + err := Login(ctx, provider, l.printIdTokenArg, l.keyPathArg, sshagent) if err != nil { return fmt.Errorf("error logging in: %w", err) } @@ -199,8 +235,9 @@ func (l *LoginCmd) Run(ctx context.Context) error { return nil } -func login(ctx context.Context, provider client.OpenIdProvider, printIdToken bool, seckeyPath string) (*LoginCmd, error) { +func login(ctx context.Context, provider client.OpenIdProvider, printIdToken bool, seckeyPath string, sshagent agent.Agent) (*LoginCmd, error) { var err error + alg := jwa.ES256 signer, err := util.GenKeyPair(alg) if err != nil { @@ -225,22 +262,8 @@ func login(ctx context.Context, provider client.OpenIdProvider, printIdToken boo return nil, fmt.Errorf("failed to generate SSH cert: %w", err) } - // Write ssh secret key and public key to filesystem - if seckeyPath != "" { - // If we have set seckeyPath then write it there - if err := writeKeys(seckeyPath, seckeyPath+".pub", seckeySshPem, certBytes); err != nil { - return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err) - } - } else { - // If keyPath isn't set then write it to the default location - if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil { - return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err) - } - } - if printIdToken { idTokenStr, err := PrettyIdToken(*pkt) - if err != nil { return nil, fmt.Errorf("failed to format ID Token: %w", err) } @@ -252,7 +275,50 @@ func login(ctx context.Context, provider client.OpenIdProvider, printIdToken boo if err != nil { return nil, fmt.Errorf("failed to parse ID Token: %w", err) } - fmt.Printf("Keys generated for identity\n%s\n", idStr) + + if _, ok := HasSSHAgent(); ok { + pubkey, _, _, _, err := ssh.ParseAuthorizedKey(certBytes) + if err != nil { + return nil, err + } + cert, ok := pubkey.(*ssh.Certificate) + if !ok { + return nil, fmt.Errorf("failed to cast to certificate") + } + + lifetime, err := getLifetime(pkt.Payload) + if err != nil { + return nil, err + } + + if err := sshagent.Add(agent.AddedKey{ + PrivateKey: signer, + Comment: "openpubkey key", + LifetimeSecs: lifetime, + Certificate: cert, + }); err != nil { + fmt.Println("failed to ssh-add key to agent: %w", err) + } + fmt.Printf("Keys generated and added to ssh-agent for identity\n%s\n", idStr) + } else { + // Write ssh secret key and public key to filesystem + if seckeyPath != "" { + // If we have set seckeyPath then write it there + if err := writeKeys(seckeyPath, seckeyPath+".pub", seckeySshPem, certBytes); err != nil { + return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err) + } + } else { + // If keyPath isn't set then write it to the default location + if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil { + return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err) + } + // Write ssh secret key and public key to filesystem + if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil { + return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err) + } + fmt.Printf("Keys generated for identity\n%s\n", idStr) + } + } return &LoginCmd{ pkt: pkt, @@ -265,8 +331,8 @@ func login(ctx context.Context, provider client.OpenIdProvider, printIdToken boo // Login performs the OIDC login procedure and creates the SSH certs/keys in the // default SSH key location. -func Login(ctx context.Context, provider client.OpenIdProvider, printIdToken bool, seckeyPath string) error { - _, err := login(ctx, provider, printIdToken, seckeyPath) +func Login(ctx context.Context, provider client.OpenIdProvider, printIdToken bool, seckeyPath string, agent agent.Agent) error { + _, err := login(ctx, provider, printIdToken, seckeyPath, agent) return err } @@ -275,21 +341,18 @@ func Login(ctx context.Context, provider client.OpenIdProvider, printIdToken boo // the PKT (and create new SSH certs) indefinitely as its token expires. This // function only returns if it encounters an error or if the supplied context is // cancelled. -func LoginWithRefresh(ctx context.Context, provider providers.RefreshableOpenIdProvider, printIdToken bool, seckeyPath string) error { - if loginResult, err := login(ctx, provider, printIdToken, seckeyPath); err != nil { +func LoginWithRefresh(ctx context.Context, provider providers.RefreshableOpenIdProvider, printIdToken bool, seckeyPath string, sshagent agent.Agent) error { + if loginResult, err := login(ctx, provider, printIdToken, seckeyPath, sshagent); err != nil { return err } else { - var claims struct { - Expiration int64 `json:"exp"` - } - if err := json.Unmarshal(loginResult.pkt.Payload, &claims); err != nil { + untilExpired, err := getExpiration(loginResult.pkt.Payload) + if err != nil { return err } for { // Sleep until a minute before expiration to give us time to refresh // the token and minimize any interruptions - untilExpired := time.Until(time.Unix(claims.Expiration, 0)) - time.Minute log.Printf("Waiting for %v before attempting to refresh id_token...", untilExpired) select { case <-time.After(untilExpired): @@ -309,16 +372,45 @@ func LoginWithRefresh(ctx context.Context, provider providers.RefreshableOpenIdP return fmt.Errorf("failed to generate SSH cert: %w", err) } - // Write ssh secret key and public key to filesystem - if seckeyPath != "" { - // If we have set seckeyPath then write it there - if err := writeKeys(seckeyPath, seckeyPath+".pub", seckeySshPem, certBytes); err != nil { - return fmt.Errorf("failed to write SSH keys to filesystem: %w", err) + if _, ok := HasSSHAgent(); ok { + pub, _, _, _, err := ssh.ParseAuthorizedKey(certBytes) + if err != nil { + return err + } + cert, ok := pub.(*ssh.Certificate) + if !ok { + return fmt.Errorf("failed to cast to certificate") + } + + if err := sshagent.Remove(pub); err != nil { + return fmt.Errorf("failed to remove old certificate from ssh-agent: %w", err) + } + + lifetime, err := getLifetime(loginResult.pkt.Payload) + if err != nil { + return fmt.Errorf("failed to parse lifetime from iap in oidc token: %w", err) + } + if err := sshagent.Add(agent.AddedKey{ + PrivateKey: loginResult.signer, + Comment: "openpubkey key", + // We don't remove the old certificate, we include a lifetime + LifetimeSecs: lifetime, + Certificate: cert, + }); err != nil { + fmt.Println("failed to ssh-add key to agent: %w", err) } } else { - // If keyPath isn't set then write it to the default location - if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil { - return fmt.Errorf("failed to write SSH keys to filesystem: %w", err) + // Write ssh secret key and public key to filesystem + if seckeyPath != "" { + // If we have set seckeyPath then write it there + if err := writeKeys(seckeyPath, seckeyPath+".pub", seckeySshPem, certBytes); err != nil { + return fmt.Errorf("failed to write SSH keys to filesystem: %w", err) + } + } else { + // If keyPath isn't set then write it to the default location + if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil { + return fmt.Errorf("failed to write SSH keys to filesystem: %w", err) + } } } @@ -336,8 +428,9 @@ func LoginWithRefresh(ctx context.Context, provider providers.RefreshableOpenIdP return fmt.Errorf("refreshed ID token payload is not base64 encoded: %w", err) } - if err = json.Unmarshal(payload, &claims); err != nil { - return fmt.Errorf("malformed refreshed ID token payload: %w", err) + untilExpired, err = getExpiration(payload) + if err != nil { + return err } } } @@ -457,14 +550,12 @@ func IdentityString(pkt pktoken.PKToken) (string, error) { } func PrettyIdToken(pkt pktoken.PKToken) (string, error) { - idt, err := oidc.NewJwt(pkt.OpToken) if err != nil { return "", err } idt_json, err := json.MarshalIndent(idt.GetClaims(), "", " ") - if err != nil { return "", err } diff --git a/test/integration/login_test.go b/test/integration/login_test.go index 6835b9d2..6631fa24 100644 --- a/test/integration/login_test.go +++ b/test/integration/login_test.go @@ -50,7 +50,7 @@ func TestLogin(t *testing.T) { opkProvider, loginURL, err := opServer.OpkProvider() require.NoError(t, err, "failed to create OPK provider") go func() { - err := commands.Login(TestCtx, opkProvider, false, "") + err := commands.Login(TestCtx, opkProvider, false, "", nil) errCh <- err }() @@ -119,7 +119,7 @@ func TestLoginCustomKeyPath(t *testing.T) { seckeyPath := filepath.Join(sshPath, "opkssh-key") go func() { - err := commands.Login(TestCtx, opkProvider, false, seckeyPath) + err := commands.Login(TestCtx, opkProvider, false, seckeyPath, nil) errCh <- err }() diff --git a/test/integration/ssh_test.go b/test/integration/ssh_test.go index 4e303d33..5f93a9eb 100644 --- a/test/integration/ssh_test.go +++ b/test/integration/ssh_test.go @@ -42,6 +42,7 @@ import ( "github.com/openpubkey/opkssh/commands" testprovider "github.com/openpubkey/opkssh/test/integration/provider" "github.com/openpubkey/opkssh/test/integration/ssh_server" + "golang.org/x/crypto/ssh/agent" "github.com/melbahja/goph" "github.com/stretchr/testify/assert" @@ -319,7 +320,7 @@ func TestEndToEndSSH(t *testing.T) { errCh := make(chan error) t.Log("------- call login cmd ------") go func() { - err := commands.Login(TestCtx, zitadelOp, false, "") + err := commands.Login(TestCtx, zitadelOp, false, "", nil) errCh <- err }() @@ -414,7 +415,7 @@ func TestEndToEndSSHAsUnprivilegedUser(t *testing.T) { errCh := make(chan error) t.Log("------- call login cmd ------") go func() { - err := commands.Login(TestCtx, zitadelOp, false, "") + err := commands.Login(TestCtx, zitadelOp, false, "", nil) errCh <- err }() @@ -527,7 +528,7 @@ func TestEndToEndSSHWithRefresh(t *testing.T) { defer cancelRefresh() t.Log("------- call login cmd ------") go func() { - err := commands.LoginWithRefresh(refreshCtx, pulseZitadelOp, false, "") + err := commands.LoginWithRefresh(refreshCtx, pulseZitadelOp, false, "", nil) errCh <- err }() @@ -651,3 +652,53 @@ func TestEndToEndSSHWithRefresh(t *testing.T) { require.NoError(t, err) require.Equal(t, serverContainer.User, strings.TrimSpace(string(out))) } + +func TestEndToEndSSHWithAgent(t *testing.T) { + var err error + + // Initializes an in-memory ssh-agent from x/crypto/ssh/agent + keyring := agent.NewKeyring() + + // Ensure we enable ssh-agent in our code + os.Setenv("SSH_AUTH_SOCK", "1") + + oidcContainer, authCallbackRedirectPort, serverContainer := spawnTestContainers(t) + zitadelOp, customTransport := createZitadelOPKSshProvider(oidcContainer.Port, authCallbackRedirectPort) + + errCh := make(chan error) + t.Log("------- call login cmd ------") + go func() { + err := commands.Login(TestCtx, zitadelOp, false, "", keyring) + errCh <- err + }() + + timeoutErr := WaitForServer(TestCtx, fmt.Sprintf("http://localhost:%d", authCallbackRedirectPort), LoginCallbackServerTimeout) + require.NoError(t, timeoutErr, "login callback server took too long to startup") + + DoOidcInteractiveLogin(t, customTransport, fmt.Sprintf("http://localhost:%d/login", authCallbackRedirectPort), "test-user@oidc.local", "verysecure") + + timeoutCtx, cancel := context.WithTimeout(TestCtx, 3*time.Second) + defer cancel() + select { + case loginErr := <-errCh: + require.NoError(t, loginErr, "failed login") + case <-timeoutCtx.Done(): + t.Fatal(timeoutCtx.Err()) + } + + authKey := goph.Auth{ssh.PublicKeysCallback(keyring.Signers)} + opkSshClient, err := goph.NewConn(&goph.Config{ + User: serverContainer.User, + Addr: serverContainer.Host, + Port: uint(serverContainer.Port), + Auth: authKey, + Timeout: goph.DefaultTimeout, + Callback: ssh.InsecureIgnoreHostKey(), + }) + require.NoError(t, err) + defer opkSshClient.Close() + + out, err := opkSshClient.Run("whoami") + require.NoError(t, err) + require.Equal(t, serverContainer.User, strings.TrimSpace(string(out))) +}