Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 130 additions & 39 deletions commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"fmt"
"io"
"log"
"net"
"os"
"path/filepath"
"strings"
Expand All @@ -41,6 +42,7 @@ import (
"github.com/openpubkey/openpubkey/util"
"github.com/openpubkey/opkssh/sshcert"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

type LoginCmd struct {
Expand All @@ -58,6 +60,30 @@ type LoginCmd struct {
principals []string
}

func HasSSHAgent() (string, bool) {
return os.LookupEnv("SSH_AUTH_SOCK")
}

func getLifetime(b []byte) (uint32, error) {
var claims struct {
IssuedAt int64 `json:"iat"`
}
if err := json.Unmarshal(b, &claims); err != nil {
return 0, fmt.Errorf("malformed refreshed ID token payload: %w", err)
}
return uint32(time.Duration(time.Until(time.Unix(claims.IssuedAt, 0).Add(23 * time.Hour)).Seconds())), nil
}

func getExpiration(b []byte) (time.Duration, error) {
var claims struct {
Expiration int64 `json:"exp"`
}
if err := json.Unmarshal(b, &claims); err != nil {
return 0, fmt.Errorf("malformed refreshed ID token payload: %w", err)
}
return time.Until(time.Unix(claims.Expiration, 0)) - time.Minute, nil
}

func NewLogin(autoRefresh bool, logDir string, disableBrowserOpenArg bool, printIdTokenArg bool, providerArg string, keyPathArg string, providerFromLdFlags providers.OpenIdProvider) *LoginCmd {
return &LoginCmd{
autoRefresh: autoRefresh,
Expand Down Expand Up @@ -180,27 +206,38 @@ func (l *LoginCmd) Run(ctx context.Context) error {
}
}

var sshagent agent.Agent
if s, ok := HasSSHAgent(); ok {
conn, err := net.Dial("unix", s)
if err != nil {
return fmt.Errorf("failed to connect to ssh-agent socket: %w", err)
}
defer conn.Close()
sshagent = agent.NewClient(conn)
}

// Execute login command
if l.autoRefresh {
if providerRefreshable, ok := provider.(providers.RefreshableOpenIdProvider); ok {
err := LoginWithRefresh(ctx, providerRefreshable, l.printIdTokenArg, l.keyPathArg)
err := LoginWithRefresh(ctx, providerRefreshable, l.printIdTokenArg, l.keyPathArg, sshagent)
if err != nil {
return fmt.Errorf("error logging in: %w", err)
}
} else {
return fmt.Errorf("supplied OpenID Provider (%v) does not support auto-refresh and auto-refresh argument set to true", provider.Issuer())
}
} else {
err := Login(ctx, provider, l.printIdTokenArg, l.keyPathArg)
err := Login(ctx, provider, l.printIdTokenArg, l.keyPathArg, sshagent)
if err != nil {
return fmt.Errorf("error logging in: %w", err)
}
}
return nil
}

func login(ctx context.Context, provider client.OpenIdProvider, printIdToken bool, seckeyPath string) (*LoginCmd, error) {
func login(ctx context.Context, provider client.OpenIdProvider, printIdToken bool, seckeyPath string, sshagent agent.Agent) (*LoginCmd, error) {
var err error

alg := jwa.ES256
signer, err := util.GenKeyPair(alg)
if err != nil {
Expand All @@ -225,22 +262,8 @@ func login(ctx context.Context, provider client.OpenIdProvider, printIdToken boo
return nil, fmt.Errorf("failed to generate SSH cert: %w", err)
}

// Write ssh secret key and public key to filesystem
if seckeyPath != "" {
// If we have set seckeyPath then write it there
if err := writeKeys(seckeyPath, seckeyPath+".pub", seckeySshPem, certBytes); err != nil {
return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
}
} else {
// If keyPath isn't set then write it to the default location
if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil {
return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
}
}

if printIdToken {
idTokenStr, err := PrettyIdToken(*pkt)

if err != nil {
return nil, fmt.Errorf("failed to format ID Token: %w", err)
}
Expand All @@ -252,7 +275,50 @@ func login(ctx context.Context, provider client.OpenIdProvider, printIdToken boo
if err != nil {
return nil, fmt.Errorf("failed to parse ID Token: %w", err)
}
fmt.Printf("Keys generated for identity\n%s\n", idStr)

if _, ok := HasSSHAgent(); ok {
pubkey, _, _, _, err := ssh.ParseAuthorizedKey(certBytes)
if err != nil {
return nil, err
}
cert, ok := pubkey.(*ssh.Certificate)
if !ok {
return nil, fmt.Errorf("failed to cast to certificate")
}

lifetime, err := getLifetime(pkt.Payload)
if err != nil {
return nil, err
}

if err := sshagent.Add(agent.AddedKey{
PrivateKey: signer,
Comment: "openpubkey key",
LifetimeSecs: lifetime,
Certificate: cert,
}); err != nil {
Comment thread
Foxboron marked this conversation as resolved.
fmt.Println("failed to ssh-add key to agent: %w", err)
}
fmt.Printf("Keys generated and added to ssh-agent for identity\n%s\n", idStr)
} else {
// Write ssh secret key and public key to filesystem
if seckeyPath != "" {
// If we have set seckeyPath then write it there
if err := writeKeys(seckeyPath, seckeyPath+".pub", seckeySshPem, certBytes); err != nil {
return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
}
} else {
// If keyPath isn't set then write it to the default location
if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil {
return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
}
// Write ssh secret key and public key to filesystem
if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil {
return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
}
fmt.Printf("Keys generated for identity\n%s\n", idStr)
}
}

return &LoginCmd{
pkt: pkt,
Expand All @@ -265,8 +331,8 @@ func login(ctx context.Context, provider client.OpenIdProvider, printIdToken boo

// Login performs the OIDC login procedure and creates the SSH certs/keys in the
// default SSH key location.
func Login(ctx context.Context, provider client.OpenIdProvider, printIdToken bool, seckeyPath string) error {
_, err := login(ctx, provider, printIdToken, seckeyPath)
func Login(ctx context.Context, provider client.OpenIdProvider, printIdToken bool, seckeyPath string, agent agent.Agent) error {
_, err := login(ctx, provider, printIdToken, seckeyPath, agent)
return err
}

Expand All @@ -275,21 +341,18 @@ func Login(ctx context.Context, provider client.OpenIdProvider, printIdToken boo
// the PKT (and create new SSH certs) indefinitely as its token expires. This
// function only returns if it encounters an error or if the supplied context is
// cancelled.
func LoginWithRefresh(ctx context.Context, provider providers.RefreshableOpenIdProvider, printIdToken bool, seckeyPath string) error {
if loginResult, err := login(ctx, provider, printIdToken, seckeyPath); err != nil {
func LoginWithRefresh(ctx context.Context, provider providers.RefreshableOpenIdProvider, printIdToken bool, seckeyPath string, sshagent agent.Agent) error {
if loginResult, err := login(ctx, provider, printIdToken, seckeyPath, sshagent); err != nil {
return err
} else {
var claims struct {
Expiration int64 `json:"exp"`
}
if err := json.Unmarshal(loginResult.pkt.Payload, &claims); err != nil {
untilExpired, err := getExpiration(loginResult.pkt.Payload)
if err != nil {
return err
}

for {
// Sleep until a minute before expiration to give us time to refresh
// the token and minimize any interruptions
untilExpired := time.Until(time.Unix(claims.Expiration, 0)) - time.Minute
log.Printf("Waiting for %v before attempting to refresh id_token...", untilExpired)
select {
case <-time.After(untilExpired):
Expand All @@ -309,16 +372,45 @@ func LoginWithRefresh(ctx context.Context, provider providers.RefreshableOpenIdP
return fmt.Errorf("failed to generate SSH cert: %w", err)
}

// Write ssh secret key and public key to filesystem
if seckeyPath != "" {
// If we have set seckeyPath then write it there
if err := writeKeys(seckeyPath, seckeyPath+".pub", seckeySshPem, certBytes); err != nil {
return fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
if _, ok := HasSSHAgent(); ok {
pub, _, _, _, err := ssh.ParseAuthorizedKey(certBytes)
if err != nil {
return err
}
cert, ok := pub.(*ssh.Certificate)
if !ok {
return fmt.Errorf("failed to cast to certificate")
}

if err := sshagent.Remove(pub); err != nil {
return fmt.Errorf("failed to remove old certificate from ssh-agent: %w", err)
}

lifetime, err := getLifetime(loginResult.pkt.Payload)
if err != nil {
return fmt.Errorf("failed to parse lifetime from iap in oidc token: %w", err)
}
if err := sshagent.Add(agent.AddedKey{
PrivateKey: loginResult.signer,
Comment: "openpubkey key",
// We don't remove the old certificate, we include a lifetime
LifetimeSecs: lifetime,
Certificate: cert,
}); err != nil {
fmt.Println("failed to ssh-add key to agent: %w", err)
}
} else {
// If keyPath isn't set then write it to the default location
if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil {
return fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
// Write ssh secret key and public key to filesystem
if seckeyPath != "" {
// If we have set seckeyPath then write it there
if err := writeKeys(seckeyPath, seckeyPath+".pub", seckeySshPem, certBytes); err != nil {
return fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
}
} else {
// If keyPath isn't set then write it to the default location
if err := writeKeysToSSHDir(seckeySshPem, certBytes); err != nil {
return fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
}
}
}

Expand All @@ -336,8 +428,9 @@ func LoginWithRefresh(ctx context.Context, provider providers.RefreshableOpenIdP
return fmt.Errorf("refreshed ID token payload is not base64 encoded: %w", err)
}

if err = json.Unmarshal(payload, &claims); err != nil {
return fmt.Errorf("malformed refreshed ID token payload: %w", err)
untilExpired, err = getExpiration(payload)
if err != nil {
return err
}
}
}
Expand Down Expand Up @@ -457,14 +550,12 @@ func IdentityString(pkt pktoken.PKToken) (string, error) {
}

func PrettyIdToken(pkt pktoken.PKToken) (string, error) {

idt, err := oidc.NewJwt(pkt.OpToken)
if err != nil {
return "", err
}

idt_json, err := json.MarshalIndent(idt.GetClaims(), "", " ")

if err != nil {
return "", err
}
Expand Down
4 changes: 2 additions & 2 deletions test/integration/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestLogin(t *testing.T) {
opkProvider, loginURL, err := opServer.OpkProvider()
require.NoError(t, err, "failed to create OPK provider")
go func() {
err := commands.Login(TestCtx, opkProvider, false, "")
err := commands.Login(TestCtx, opkProvider, false, "", nil)
errCh <- err
}()

Expand Down Expand Up @@ -119,7 +119,7 @@ func TestLoginCustomKeyPath(t *testing.T) {
seckeyPath := filepath.Join(sshPath, "opkssh-key")

go func() {
err := commands.Login(TestCtx, opkProvider, false, seckeyPath)
err := commands.Login(TestCtx, opkProvider, false, seckeyPath, nil)
errCh <- err
}()

Expand Down
57 changes: 54 additions & 3 deletions test/integration/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"github.com/openpubkey/opkssh/commands"
testprovider "github.com/openpubkey/opkssh/test/integration/provider"
"github.com/openpubkey/opkssh/test/integration/ssh_server"
"golang.org/x/crypto/ssh/agent"

"github.com/melbahja/goph"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -319,7 +320,7 @@ func TestEndToEndSSH(t *testing.T) {
errCh := make(chan error)
t.Log("------- call login cmd ------")
go func() {
err := commands.Login(TestCtx, zitadelOp, false, "")
err := commands.Login(TestCtx, zitadelOp, false, "", nil)
errCh <- err
}()

Expand Down Expand Up @@ -414,7 +415,7 @@ func TestEndToEndSSHAsUnprivilegedUser(t *testing.T) {
errCh := make(chan error)
t.Log("------- call login cmd ------")
go func() {
err := commands.Login(TestCtx, zitadelOp, false, "")
err := commands.Login(TestCtx, zitadelOp, false, "", nil)
errCh <- err
}()

Expand Down Expand Up @@ -527,7 +528,7 @@ func TestEndToEndSSHWithRefresh(t *testing.T) {
defer cancelRefresh()
t.Log("------- call login cmd ------")
go func() {
err := commands.LoginWithRefresh(refreshCtx, pulseZitadelOp, false, "")
err := commands.LoginWithRefresh(refreshCtx, pulseZitadelOp, false, "", nil)
errCh <- err
}()

Expand Down Expand Up @@ -651,3 +652,53 @@ func TestEndToEndSSHWithRefresh(t *testing.T) {
require.NoError(t, err)
require.Equal(t, serverContainer.User, strings.TrimSpace(string(out)))
}

func TestEndToEndSSHWithAgent(t *testing.T) {
var err error

// Initializes an in-memory ssh-agent from x/crypto/ssh/agent
keyring := agent.NewKeyring()

// Ensure we enable ssh-agent in our code
os.Setenv("SSH_AUTH_SOCK", "1")

oidcContainer, authCallbackRedirectPort, serverContainer := spawnTestContainers(t)
zitadelOp, customTransport := createZitadelOPKSshProvider(oidcContainer.Port, authCallbackRedirectPort)

errCh := make(chan error)
t.Log("------- call login cmd ------")
go func() {
err := commands.Login(TestCtx, zitadelOp, false, "", keyring)
errCh <- err
}()

timeoutErr := WaitForServer(TestCtx, fmt.Sprintf("http://localhost:%d", authCallbackRedirectPort), LoginCallbackServerTimeout)
require.NoError(t, timeoutErr, "login callback server took too long to startup")

DoOidcInteractiveLogin(t, customTransport, fmt.Sprintf("http://localhost:%d/login", authCallbackRedirectPort), "test-user@oidc.local", "verysecure")

timeoutCtx, cancel := context.WithTimeout(TestCtx, 3*time.Second)
defer cancel()
select {
case loginErr := <-errCh:
require.NoError(t, loginErr, "failed login")
case <-timeoutCtx.Done():
t.Fatal(timeoutCtx.Err())
}

authKey := goph.Auth{ssh.PublicKeysCallback(keyring.Signers)}
opkSshClient, err := goph.NewConn(&goph.Config{
User: serverContainer.User,
Addr: serverContainer.Host,
Port: uint(serverContainer.Port),
Auth: authKey,
Timeout: goph.DefaultTimeout,
Callback: ssh.InsecureIgnoreHostKey(),
})
require.NoError(t, err)
defer opkSshClient.Close()

out, err := opkSshClient.Run("whoami")
require.NoError(t, err)
require.Equal(t, serverContainer.User, strings.TrimSpace(string(out)))
}
Loading