From f669dc4b57582c97a07eb48ac7617d7d23fe950b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 22 Jul 2021 18:53:33 +0000 Subject: [PATCH 01/12] build(deps): bump helm/kind-action from 1.1.0 to 1.2.0 Bumps [helm/kind-action](https://github.com/helm/kind-action) from 1.1.0 to 1.2.0. - [Release notes](https://github.com/helm/kind-action/releases) - [Commits](https://github.com/helm/kind-action/compare/v1.1.0...v1.2.0) --- updated-dependencies: - dependency-name: helm/kind-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 320158745e..8c6ee6c6f0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -58,7 +58,7 @@ jobs: run: docker-compose -f docker-compose.test.yaml up -d - name: Create kind cluster - uses: helm/kind-action@v1.1.0 + uses: helm/kind-action@v1.2.0 with: version: v0.11.1 node_image: kindest/node:v1.19.11@sha256:07db187ae84b4b7de440a73886f008cf903fcf5764ba8106a9fd5243d6f32729 From bf3b81ca3b22d71b818578254baf20d32bcd3660 Mon Sep 17 00:00:00 2001 From: acohen4 Date: Thu, 22 Jul 2021 11:59:16 -0700 Subject: [PATCH 02/12] Adding an oidcConnector config to decide whether to validate that the Connector Callback must contain the same host as the issuer --- connector/oidc/oidc.go | 63 +++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index b752f9dac5..dfea950244 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -44,6 +44,9 @@ type Config struct { // InsecureEnableGroups enables groups claims. This is disabled by default until https://github.com/dexidp/dex/issues/1065 is resolved InsecureEnableGroups bool `json:"insecureEnableGroups"` + // Skips checking whether the requested domain in the Login Callback matches the configured Issuer + InsecureSkipIssuerCallbackDomainCheck bool `json:"insecureSkipIssuerCallbackDomainCheck"` + // GetUserInfo uses the userinfo endpoint to get additional claims for // the token. This is especially useful where upstreams return "thin" // id tokens @@ -144,18 +147,19 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e verifier: provider.Verifier( &oidc.Config{ClientID: clientID}, ), - logger: logger, - cancel: cancel, - hostedDomains: c.HostedDomains, - insecureSkipEmailVerified: c.InsecureSkipEmailVerified, - insecureEnableGroups: c.InsecureEnableGroups, - getUserInfo: c.GetUserInfo, - promptType: c.PromptType, - userIDKey: c.UserIDKey, - userNameKey: c.UserNameKey, - preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, - emailKey: c.ClaimMapping.EmailKey, - groupsKey: c.ClaimMapping.GroupsKey, + logger: logger, + cancel: cancel, + hostedDomains: c.HostedDomains, + insecureSkipEmailVerified: c.InsecureSkipEmailVerified, + insecureEnableGroups: c.InsecureEnableGroups, + insecureSkipIssuerCallbackDomainCheck: c.InsecureSkipIssuerCallbackDomainCheck, + getUserInfo: c.GetUserInfo, + promptType: c.PromptType, + userIDKey: c.UserIDKey, + userNameKey: c.UserNameKey, + preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, + emailKey: c.ClaimMapping.EmailKey, + groupsKey: c.ClaimMapping.GroupsKey, }, nil } @@ -165,22 +169,23 @@ var ( ) type oidcConnector struct { - provider *oidc.Provider - redirectURI string - oauth2Config *oauth2.Config - verifier *oidc.IDTokenVerifier - cancel context.CancelFunc - logger log.Logger - hostedDomains []string - insecureSkipEmailVerified bool - insecureEnableGroups bool - getUserInfo bool - promptType string - userIDKey string - userNameKey string - preferredUsernameKey string - emailKey string - groupsKey string + provider *oidc.Provider + redirectURI string + oauth2Config *oauth2.Config + verifier *oidc.IDTokenVerifier + cancel context.CancelFunc + logger log.Logger + hostedDomains []string + insecureSkipEmailVerified bool + insecureEnableGroups bool + insecureSkipIssuerCallbackDomainCheck bool + getUserInfo bool + promptType string + userIDKey string + userNameKey string + preferredUsernameKey string + emailKey string + groupsKey string } func (c *oidcConnector) Close() error { @@ -189,7 +194,7 @@ func (c *oidcConnector) Close() error { } func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { - if c.redirectURI != callbackURL { + if c.redirectURI != callbackURL && !c.insecureSkipIssuerCallbackDomainCheck { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } From 56fc504b721fbc12b26d91c556192637a7e9fe9a Mon Sep 17 00:00:00 2001 From: Alon Cohen Date: Wed, 11 Aug 2021 14:23:33 -0400 Subject: [PATCH 03/12] pass login_hint to OIDC connector's /auth endpoint in auth_code flow (#12) * pass login_hint to OIDC connector's /auth endpoint in auth_code flow * Also pass non-OIDC Standard params that may be used by specific connectors --- connector/authproxy/authproxy.go | 2 +- connector/bitbucketcloud/bitbucketcloud.go | 3 +- connector/connector.go | 3 +- connector/gitea/gitea.go | 3 +- connector/github/github.go | 3 +- connector/gitlab/gitlab.go | 3 +- connector/google/google.go | 3 +- connector/linkedin/linkedin.go | 3 +- connector/microsoft/microsoft.go | 3 +- connector/mock/connectortest.go | 2 +- connector/oidc/oidc.go | 21 +++- connector/openshift/openshift.go | 3 +- server/handlers.go | 4 +- server/oauth2.go | 43 ++++---- server/oauth2_test.go | 2 +- storage/ent/db/authrequest.go | 12 ++- storage/ent/db/authrequest/authrequest.go | 5 + storage/ent/db/authrequest/where.go | 118 +++++++++++++++++++++ storage/ent/db/authrequest_create.go | 29 +++++ storage/ent/db/authrequest_update.go | 42 ++++++++ storage/ent/db/migrate/schema.go | 1 + storage/ent/db/mutation.go | 56 +++++++++- storage/ent/db/runtime.go | 4 + storage/ent/schema/authrequest.go | 3 + storage/sql/crud.go | 16 +-- storage/sql/migrate.go | 7 ++ storage/storage.go | 1 + 27 files changed, 351 insertions(+), 44 deletions(-) diff --git a/connector/authproxy/authproxy.go b/connector/authproxy/authproxy.go index 853e5ee29f..221058c9e1 100644 --- a/connector/authproxy/authproxy.go +++ b/connector/authproxy/authproxy.go @@ -37,7 +37,7 @@ type callback struct { } // LoginURL returns the URL to redirect the user to login with. -func (m *callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (m *callback) LoginURL(s connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { u, err := url.Parse(callbackURL) if err != nil { return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) diff --git a/connector/bitbucketcloud/bitbucketcloud.go b/connector/bitbucketcloud/bitbucketcloud.go index e81893da07..51113d8bfa 100644 --- a/connector/bitbucketcloud/bitbucketcloud.go +++ b/connector/bitbucketcloud/bitbucketcloud.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "sync" "time" @@ -111,7 +112,7 @@ func (b *bitbucketConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { if b.redirectURI != callbackURL { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI) } diff --git a/connector/connector.go b/connector/connector.go index aab994b468..0f34ebc88a 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -4,6 +4,7 @@ package connector import ( "context" "net/http" + "net/url" ) // Connector is a mechanism for federating login to a remote identity service. @@ -63,7 +64,7 @@ type CallbackConnector interface { // requested if one has already been issues. There's no good general answer // for these kind of restrictions, and may require this package to become more // aware of the global set of user/connector interactions. - LoginURL(s Scopes, callbackURL, state string) (string, error) + LoginURL(s Scopes, callbackURL, state string, forwardedParams url.Values) (string, error) // Handle the callback to the server and return an identity. HandleCallback(s Scopes, r *http.Request) (identity Identity, err error) diff --git a/connector/gitea/gitea.go b/connector/gitea/gitea.go index 33cc3126e6..90385c6c72 100644 --- a/connector/gitea/gitea.go +++ b/connector/gitea/gitea.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "strconv" "sync" "time" @@ -82,7 +83,7 @@ func (c *giteaConnector) oauth2Config(_ connector.Scopes) *oauth2.Config { } } -func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { if c.redirectURI != callbackURL { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) } diff --git a/connector/github/github.go b/connector/github/github.go index 02f2cae804..3df415d322 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -11,6 +11,7 @@ import ( "io/ioutil" "net" "net/http" + "net/url" "regexp" "strconv" "strings" @@ -187,7 +188,7 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { if c.redirectURI != callbackURL { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } diff --git a/connector/gitlab/gitlab.go b/connector/gitlab/gitlab.go index e40601402d..648a825388 100644 --- a/connector/gitlab/gitlab.go +++ b/connector/gitlab/gitlab.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "strconv" "golang.org/x/oauth2" @@ -98,7 +99,7 @@ func (c *gitlabConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { if c.redirectURI != callbackURL { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) } diff --git a/connector/google/google.go b/connector/google/google.go index eccb1fc7d7..594181c59e 100644 --- a/connector/google/google.go +++ b/connector/google/google.go @@ -7,6 +7,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -120,7 +121,7 @@ func (c *googleConnector) Close() error { return nil } -func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { if c.redirectURI != callbackURL { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } diff --git a/connector/linkedin/linkedin.go b/connector/linkedin/linkedin.go index 1c8312c11e..c3c63ff35e 100644 --- a/connector/linkedin/linkedin.go +++ b/connector/linkedin/linkedin.go @@ -7,6 +7,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "strings" "golang.org/x/oauth2" @@ -62,7 +63,7 @@ var ( ) // LoginURL returns an access token request URL -func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { if c.oauth2Config.RedirectURL != callbackURL { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.oauth2Config.RedirectURL) diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index 328ea15274..2195f74451 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" "sync" "time" @@ -151,7 +152,7 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { if c.redirectURI != callbackURL { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index 5089914ca1..b35ac5519d 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -43,7 +43,7 @@ type Callback struct { } // LoginURL returns the URL to redirect the user to login with. -func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { u, err := url.Parse(callbackURL) if err != nil { return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index dfea950244..ec00814b19 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -59,6 +59,8 @@ type Config struct { // PromptType will be used fot the prompt parameter (when offline_access, by default prompt=consent) PromptType string `json:"promptType"` + ForwardedLoginParams []string `json:"forwardedLoginParams"` + ClaimMapping struct { // Configurable key which contains the preferred username claims PreferredUsernameKey string `json:"preferred_username"` // defaults to "preferred_username" @@ -160,6 +162,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, emailKey: c.ClaimMapping.EmailKey, groupsKey: c.ClaimMapping.GroupsKey, + forwardedLoginParams: c.ForwardedLoginParams, }, nil } @@ -186,6 +189,7 @@ type oidcConnector struct { preferredUsernameKey string emailKey string groupsKey string + forwardedLoginParams []string } func (c *oidcConnector) Close() error { @@ -193,7 +197,7 @@ func (c *oidcConnector) Close() error { return nil } -func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string, forwardParams url.Values) (string, error) { if c.redirectURI != callbackURL && !c.insecureSkipIssuerCallbackDomainCheck { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } @@ -210,6 +214,21 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) if s.OfflineAccess { opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType)) } + + for p := range forwardParams { + paramApproved := false + for _, approvedParam := range c.forwardedLoginParams { + if p == approvedParam { + paramApproved = true + break + } + } + if paramApproved { + opts = append(opts, oauth2.SetAuthURLParam(p, forwardParams.Get(p))) + } else { + c.logger.Infof("oidc: auth query parameter %s, is unapproved and was not forwarded to the connector idp", p) + } + } return c.oauth2Config.AuthCodeURL(state, opts...), nil } diff --git a/connector/openshift/openshift.go b/connector/openshift/openshift.go index f06e8f8045..1469e55609 100644 --- a/connector/openshift/openshift.go +++ b/connector/openshift/openshift.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "net" "net/http" + "net/url" "strings" "time" @@ -117,7 +118,7 @@ func (c *openshiftConnector) Close() error { } // LoginURL returns the URL to redirect the user to login with. -func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string, _ url.Values) (string, error) { if c.redirectURI != callbackURL { return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } diff --git a/server/handlers.go b/server/handlers.go index 4144dd1f87..fda32fe8c5 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -184,7 +184,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { - authReq, err := s.parseAuthorizationRequest(r) + authReq, forwardParams, err := s.parseAuthorizationRequest(r) if err != nil { s.logger.Errorf("Failed to parse authorization request: %v", err) status := http.StatusInternalServerError @@ -250,7 +250,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Use the auth request ID as the "state" token. // // TODO(ericchiang): Is this appropriate or should we also be using a nonce? - callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) + callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID, forwardParams) if err != nil { s.logger.Errorf("Connector %q returned error when creating callback: %v", connID, err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") diff --git a/server/oauth2.go b/server/oauth2.go index 00beb6ff4b..4e9b612b7c 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -406,14 +406,14 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str } // parse the initial request from the OAuth2 client. -func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, error) { +func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, url.Values, error) { if err := r.ParseForm(); err != nil { - return nil, &authErr{"", "", errInvalidRequest, "Failed to parse request body."} + return nil, nil, &authErr{"", "", errInvalidRequest, "Failed to parse request body."} } q := r.Form redirectURI, err := url.QueryUnescape(q.Get("redirect_uri")) if err != nil { - return nil, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."} + return nil, nil, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."} } clientID := q.Get("client_id") @@ -435,25 +435,25 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques if err != nil { if err == storage.ErrNotFound { description := fmt.Sprintf("Invalid client_id (%q).", clientID) - return nil, &authErr{"", "", errUnauthorizedClient, description} + return nil, nil, &authErr{"", "", errUnauthorizedClient, description} } s.logger.Errorf("Failed to get client: %v", err) - return nil, &authErr{"", "", errServerError, ""} + return nil, nil, &authErr{"", "", errServerError, ""} } if connectorID != "" { connectors, err := s.storage.ListConnectors() if err != nil { - return nil, &authErr{"", "", errServerError, "Unable to retrieve connectors"} + return nil, nil, &authErr{"", "", errServerError, "Unable to retrieve connectors"} } if !validateConnectorID(connectors, connectorID) { - return nil, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"} + return nil, nil, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"} } } if !validateRedirectURI(client, redirectURI) { description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) - return nil, &authErr{"", "", errInvalidRequest, description} + return nil, nil, &authErr{"", "", errInvalidRequest, description} } if redirectURI == deviceCallbackURI && client.Public { redirectURI = s.issuerURL.Path + deviceCallbackURI @@ -467,12 +467,12 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques // dex doesn't support request parameter and must return request_not_supported error // https://openid.net/specs/openid-connect-core-1_0.html#6.1 if q.Get("request") != "" { - return nil, newErr(errRequestNotSupported, "Server does not support request parameter.") + return nil, nil, newErr(errRequestNotSupported, "Server does not support request parameter.") } if codeChallengeMethod != codeChallengeMethodS256 && codeChallengeMethod != codeChallengeMethodPlain { description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod) - return nil, newErr(errInvalidRequest, description) + return nil, nil, newErr(errInvalidRequest, description) } var ( @@ -494,7 +494,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques isTrusted, err := s.validateCrossClientTrust(clientID, peerID) if err != nil { - return nil, newErr(errServerError, "Internal server error.") + return nil, nil, newErr(errServerError, "Internal server error.") } if !isTrusted { invalidScopes = append(invalidScopes, scope) @@ -502,13 +502,13 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques } } if !hasOpenIDScope { - return nil, newErr(errInvalidScope, `Missing required scope(s) ["openid"].`) + return nil, nil, newErr(errInvalidScope, `Missing required scope(s) ["openid"].`) } if len(unrecognized) > 0 { - return nil, newErr(errInvalidScope, "Unrecognized scope(s) %q", unrecognized) + return nil, nil, newErr(errInvalidScope, "Unrecognized scope(s) %q", unrecognized) } if len(invalidScopes) > 0 { - return nil, newErr(errInvalidScope, "Client can't request scope(s) %q", invalidScopes) + return nil, nil, newErr(errInvalidScope, "Client can't request scope(s) %q", invalidScopes) } var rt struct { @@ -526,23 +526,23 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques case responseTypeToken: rt.token = true default: - return nil, newErr(errInvalidRequest, "Invalid response type %q", responseType) + return nil, nil, newErr(errInvalidRequest, "Invalid response type %q", responseType) } if !s.supportedResponseTypes[responseType] { - return nil, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType) + return nil, nil, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType) } } if len(responseTypes) == 0 { - return nil, newErr(errInvalidRequest, "No response_type provided") + return nil, nil, newErr(errInvalidRequest, "No response_type provided") } if rt.token && !rt.code && !rt.idToken { // "token" can't be provided by its own. // // https://openid.net/specs/openid-connect-core-1_0.html#Authentication - return nil, newErr(errInvalidRequest, "Response type 'token' must be provided with type 'id_token' and/or 'code'") + return nil, nil, newErr(errInvalidRequest, "Response type 'token' must be provided with type 'id_token' and/or 'code'") } if !rt.code { // Either "id_token token" or "id_token" has been provided which implies the @@ -550,13 +550,13 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques // // https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest if nonce == "" { - return nil, newErr(errInvalidRequest, "Response type 'token' requires a 'nonce' value.") + return nil, nil, newErr(errInvalidRequest, "Response type 'token' requires a 'nonce' value.") } } if rt.token { if redirectURI == redirectURIOOB { err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB) - return nil, newErr(errInvalidRequest, err) + return nil, nil, newErr(errInvalidRequest, err) } } @@ -570,11 +570,12 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques RedirectURI: redirectURI, ResponseTypes: responseTypes, ConnectorID: connectorID, + LoginHint: q.Get("login_hint"), PKCE: storage.PKCE{ CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, }, - }, nil + }, r.Form, nil } func parseCrossClientScope(scope string) (peerID string, ok bool) { diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 518e22ee86..6c3650add2 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -320,7 +320,7 @@ func TestParseAuthorizationRequest(t *testing.T) { req = httptest.NewRequest("GET", httpServer.URL+"/auth?"+params.Encode(), nil) } - _, err := server.parseAuthorizationRequest(req) + _, _, err := server.parseAuthorizationRequest(req) if tc.wantErr { require.Error(t, err) if tc.exactError != nil { diff --git a/storage/ent/db/authrequest.go b/storage/ent/db/authrequest.go index ed64d9f679..ad8be7f3b4 100644 --- a/storage/ent/db/authrequest.go +++ b/storage/ent/db/authrequest.go @@ -55,6 +55,8 @@ type AuthRequest struct { CodeChallenge string `json:"code_challenge,omitempty"` // CodeChallengeMethod holds the value of the "code_challenge_method" field. CodeChallengeMethod string `json:"code_challenge_method,omitempty"` + // LoginHint holds the value of the "login_hint" field. + LoginHint string `json:"login_hint,omitempty"` } // scanValues returns the types for scanning values from sql.Rows. @@ -66,7 +68,7 @@ func (*AuthRequest) scanValues(columns []string) ([]interface{}, error) { values[i] = new([]byte) case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified: values[i] = new(sql.NullBool) - case authrequest.FieldID, authrequest.FieldClientID, authrequest.FieldRedirectURI, authrequest.FieldNonce, authrequest.FieldState, authrequest.FieldClaimsUserID, authrequest.FieldClaimsUsername, authrequest.FieldClaimsEmail, authrequest.FieldClaimsPreferredUsername, authrequest.FieldConnectorID, authrequest.FieldCodeChallenge, authrequest.FieldCodeChallengeMethod: + case authrequest.FieldID, authrequest.FieldClientID, authrequest.FieldRedirectURI, authrequest.FieldNonce, authrequest.FieldState, authrequest.FieldClaimsUserID, authrequest.FieldClaimsUsername, authrequest.FieldClaimsEmail, authrequest.FieldClaimsPreferredUsername, authrequest.FieldConnectorID, authrequest.FieldCodeChallenge, authrequest.FieldCodeChallengeMethod, authrequest.FieldLoginHint: values[i] = new(sql.NullString) case authrequest.FieldExpiry: values[i] = new(sql.NullTime) @@ -214,6 +216,12 @@ func (ar *AuthRequest) assignValues(columns []string, values []interface{}) erro } else if value.Valid { ar.CodeChallengeMethod = value.String } + case authrequest.FieldLoginHint: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field login_hint", values[i]) + } else if value.Valid { + ar.LoginHint = value.String + } } } return nil @@ -282,6 +290,8 @@ func (ar *AuthRequest) String() string { builder.WriteString(ar.CodeChallenge) builder.WriteString(", code_challenge_method=") builder.WriteString(ar.CodeChallengeMethod) + builder.WriteString(", login_hint=") + builder.WriteString(ar.LoginHint) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/authrequest/authrequest.go b/storage/ent/db/authrequest/authrequest.go index 00094a7372..7e872acb8f 100644 --- a/storage/ent/db/authrequest/authrequest.go +++ b/storage/ent/db/authrequest/authrequest.go @@ -45,6 +45,8 @@ const ( FieldCodeChallenge = "code_challenge" // FieldCodeChallengeMethod holds the string denoting the code_challenge_method field in the database. FieldCodeChallengeMethod = "code_challenge_method" + // FieldLoginHint holds the string denoting the login_hint field in the database. + FieldLoginHint = "login_hint" // Table holds the table name of the authrequest in the database. Table = "auth_requests" ) @@ -71,6 +73,7 @@ var Columns = []string{ FieldExpiry, FieldCodeChallenge, FieldCodeChallengeMethod, + FieldLoginHint, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -90,6 +93,8 @@ var ( DefaultCodeChallenge string // DefaultCodeChallengeMethod holds the default value on creation for the "code_challenge_method" field. DefaultCodeChallengeMethod string + // DefaultLoginHint holds the default value on creation for the "login_hint" field. + DefaultLoginHint string // IDValidator is a validator for the "id" field. It is called by the builders before save. IDValidator func(string) error ) diff --git a/storage/ent/db/authrequest/where.go b/storage/ent/db/authrequest/where.go index 31ae7a0396..471af1fbae 100644 --- a/storage/ent/db/authrequest/where.go +++ b/storage/ent/db/authrequest/where.go @@ -204,6 +204,13 @@ func CodeChallengeMethod(v string) predicate.AuthRequest { }) } +// LoginHint applies equality check predicate on the "login_hint" field. It's identical to LoginHintEQ. +func LoginHint(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldLoginHint), v)) + }) +} + // ClientIDEQ applies the EQ predicate on the "client_id" field. func ClientIDEQ(v string) predicate.AuthRequest { return predicate.AuthRequest(func(s *sql.Selector) { @@ -1675,6 +1682,117 @@ func CodeChallengeMethodContainsFold(v string) predicate.AuthRequest { }) } +// LoginHintEQ applies the EQ predicate on the "login_hint" field. +func LoginHintEQ(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintNEQ applies the NEQ predicate on the "login_hint" field. +func LoginHintNEQ(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintIn applies the In predicate on the "login_hint" field. +func LoginHintIn(vs ...string) predicate.AuthRequest { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.AuthRequest(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldLoginHint), v...)) + }) +} + +// LoginHintNotIn applies the NotIn predicate on the "login_hint" field. +func LoginHintNotIn(vs ...string) predicate.AuthRequest { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.AuthRequest(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldLoginHint), v...)) + }) +} + +// LoginHintGT applies the GT predicate on the "login_hint" field. +func LoginHintGT(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintGTE applies the GTE predicate on the "login_hint" field. +func LoginHintGTE(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintLT applies the LT predicate on the "login_hint" field. +func LoginHintLT(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintLTE applies the LTE predicate on the "login_hint" field. +func LoginHintLTE(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintContains applies the Contains predicate on the "login_hint" field. +func LoginHintContains(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintHasPrefix applies the HasPrefix predicate on the "login_hint" field. +func LoginHintHasPrefix(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintHasSuffix applies the HasSuffix predicate on the "login_hint" field. +func LoginHintHasSuffix(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintEqualFold applies the EqualFold predicate on the "login_hint" field. +func LoginHintEqualFold(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldLoginHint), v)) + }) +} + +// LoginHintContainsFold applies the ContainsFold predicate on the "login_hint" field. +func LoginHintContainsFold(v string) predicate.AuthRequest { + return predicate.AuthRequest(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldLoginHint), v)) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.AuthRequest) predicate.AuthRequest { return predicate.AuthRequest(func(s *sql.Selector) { diff --git a/storage/ent/db/authrequest_create.go b/storage/ent/db/authrequest_create.go index e7a2c8cebe..2cdfe6d980 100644 --- a/storage/ent/db/authrequest_create.go +++ b/storage/ent/db/authrequest_create.go @@ -158,6 +158,20 @@ func (arc *AuthRequestCreate) SetNillableCodeChallengeMethod(s *string) *AuthReq return arc } +// SetLoginHint sets the "login_hint" field. +func (arc *AuthRequestCreate) SetLoginHint(s string) *AuthRequestCreate { + arc.mutation.SetLoginHint(s) + return arc +} + +// SetNillableLoginHint sets the "login_hint" field if the given value is not nil. +func (arc *AuthRequestCreate) SetNillableLoginHint(s *string) *AuthRequestCreate { + if s != nil { + arc.SetLoginHint(*s) + } + return arc +} + // SetID sets the "id" field. func (arc *AuthRequestCreate) SetID(s string) *AuthRequestCreate { arc.mutation.SetID(s) @@ -228,6 +242,10 @@ func (arc *AuthRequestCreate) defaults() { v := authrequest.DefaultCodeChallengeMethod arc.mutation.SetCodeChallengeMethod(v) } + if _, ok := arc.mutation.LoginHint(); !ok { + v := authrequest.DefaultLoginHint + arc.mutation.SetLoginHint(v) + } } // check runs all checks and user-defined validators on the builder. @@ -277,6 +295,9 @@ func (arc *AuthRequestCreate) check() error { if _, ok := arc.mutation.CodeChallengeMethod(); !ok { return &ValidationError{Name: "code_challenge_method", err: errors.New("db: missing required field \"code_challenge_method\"")} } + if _, ok := arc.mutation.LoginHint(); !ok { + return &ValidationError{Name: "login_hint", err: errors.New("db: missing required field \"login_hint\"")} + } if v, ok := arc.mutation.ID(); ok { if err := authrequest.IDValidator(v); err != nil { return &ValidationError{Name: "id", err: fmt.Errorf("db: validator failed for field \"id\": %w", err)} @@ -463,6 +484,14 @@ func (arc *AuthRequestCreate) createSpec() (*AuthRequest, *sqlgraph.CreateSpec) }) _node.CodeChallengeMethod = value } + if value, ok := arc.mutation.LoginHint(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: authrequest.FieldLoginHint, + }) + _node.LoginHint = value + } return _node, _spec } diff --git a/storage/ent/db/authrequest_update.go b/storage/ent/db/authrequest_update.go index 2d3f859443..4eb5b2b3fc 100644 --- a/storage/ent/db/authrequest_update.go +++ b/storage/ent/db/authrequest_update.go @@ -189,6 +189,20 @@ func (aru *AuthRequestUpdate) SetNillableCodeChallengeMethod(s *string) *AuthReq return aru } +// SetLoginHint sets the "login_hint" field. +func (aru *AuthRequestUpdate) SetLoginHint(s string) *AuthRequestUpdate { + aru.mutation.SetLoginHint(s) + return aru +} + +// SetNillableLoginHint sets the "login_hint" field if the given value is not nil. +func (aru *AuthRequestUpdate) SetNillableLoginHint(s *string) *AuthRequestUpdate { + if s != nil { + aru.SetLoginHint(*s) + } + return aru +} + // Mutation returns the AuthRequestMutation object of the builder. func (aru *AuthRequestUpdate) Mutation() *AuthRequestMutation { return aru.mutation @@ -420,6 +434,13 @@ func (aru *AuthRequestUpdate) sqlSave(ctx context.Context) (n int, err error) { Column: authrequest.FieldCodeChallengeMethod, }) } + if value, ok := aru.mutation.LoginHint(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: authrequest.FieldLoginHint, + }) + } if n, err = sqlgraph.UpdateNodes(ctx, aru.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authrequest.Label} @@ -601,6 +622,20 @@ func (aruo *AuthRequestUpdateOne) SetNillableCodeChallengeMethod(s *string) *Aut return aruo } +// SetLoginHint sets the "login_hint" field. +func (aruo *AuthRequestUpdateOne) SetLoginHint(s string) *AuthRequestUpdateOne { + aruo.mutation.SetLoginHint(s) + return aruo +} + +// SetNillableLoginHint sets the "login_hint" field if the given value is not nil. +func (aruo *AuthRequestUpdateOne) SetNillableLoginHint(s *string) *AuthRequestUpdateOne { + if s != nil { + aruo.SetLoginHint(*s) + } + return aruo +} + // Mutation returns the AuthRequestMutation object of the builder. func (aruo *AuthRequestUpdateOne) Mutation() *AuthRequestMutation { return aruo.mutation @@ -856,6 +891,13 @@ func (aruo *AuthRequestUpdateOne) sqlSave(ctx context.Context) (_node *AuthReque Column: authrequest.FieldCodeChallengeMethod, }) } + if value, ok := aruo.mutation.LoginHint(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: authrequest.FieldLoginHint, + }) + } _node = &AuthRequest{config: aruo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index d5b1f535d9..ec8a56ae30 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -56,6 +56,7 @@ var ( {Name: "expiry", Type: field.TypeTime}, {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "login_hint", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, } // AuthRequestsTable holds the schema information for the "auth_requests" table. AuthRequestsTable = &schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index 7ccab3f2b3..c749d343f4 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -1180,6 +1180,7 @@ type AuthRequestMutation struct { expiry *time.Time code_challenge *string code_challenge_method *string + login_hint *string clearedFields map[string]struct{} done bool oldValue func(context.Context) (*AuthRequest, error) @@ -2007,6 +2008,42 @@ func (m *AuthRequestMutation) ResetCodeChallengeMethod() { m.code_challenge_method = nil } +// SetLoginHint sets the "login_hint" field. +func (m *AuthRequestMutation) SetLoginHint(s string) { + m.login_hint = &s +} + +// LoginHint returns the value of the "login_hint" field in the mutation. +func (m *AuthRequestMutation) LoginHint() (r string, exists bool) { + v := m.login_hint + if v == nil { + return + } + return *v, true +} + +// OldLoginHint returns the old "login_hint" field's value of the AuthRequest entity. +// If the AuthRequest object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthRequestMutation) OldLoginHint(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, fmt.Errorf("OldLoginHint is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, fmt.Errorf("OldLoginHint requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLoginHint: %w", err) + } + return oldValue.LoginHint, nil +} + +// ResetLoginHint resets all changes to the "login_hint" field. +func (m *AuthRequestMutation) ResetLoginHint() { + m.login_hint = nil +} + // Op returns the operation name. func (m *AuthRequestMutation) Op() Op { return m.op @@ -2021,7 +2058,7 @@ func (m *AuthRequestMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AuthRequestMutation) Fields() []string { - fields := make([]string, 0, 19) + fields := make([]string, 0, 20) if m.client_id != nil { fields = append(fields, authrequest.FieldClientID) } @@ -2079,6 +2116,9 @@ func (m *AuthRequestMutation) Fields() []string { if m.code_challenge_method != nil { fields = append(fields, authrequest.FieldCodeChallengeMethod) } + if m.login_hint != nil { + fields = append(fields, authrequest.FieldLoginHint) + } return fields } @@ -2125,6 +2165,8 @@ func (m *AuthRequestMutation) Field(name string) (ent.Value, bool) { return m.CodeChallenge() case authrequest.FieldCodeChallengeMethod: return m.CodeChallengeMethod() + case authrequest.FieldLoginHint: + return m.LoginHint() } return nil, false } @@ -2172,6 +2214,8 @@ func (m *AuthRequestMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldCodeChallenge(ctx) case authrequest.FieldCodeChallengeMethod: return m.OldCodeChallengeMethod(ctx) + case authrequest.FieldLoginHint: + return m.OldLoginHint(ctx) } return nil, fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2314,6 +2358,13 @@ func (m *AuthRequestMutation) SetField(name string, value ent.Value) error { } m.SetCodeChallengeMethod(v) return nil + case authrequest.FieldLoginHint: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLoginHint(v) + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2447,6 +2498,9 @@ func (m *AuthRequestMutation) ResetField(name string) error { case authrequest.FieldCodeChallengeMethod: m.ResetCodeChallengeMethod() return nil + case authrequest.FieldLoginHint: + m.ResetLoginHint() + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } diff --git a/storage/ent/db/runtime.go b/storage/ent/db/runtime.go index 49f4157ac0..0566e0fef6 100644 --- a/storage/ent/db/runtime.go +++ b/storage/ent/db/runtime.go @@ -82,6 +82,10 @@ func init() { authrequestDescCodeChallengeMethod := authrequestFields[19].Descriptor() // authrequest.DefaultCodeChallengeMethod holds the default value on creation for the code_challenge_method field. authrequest.DefaultCodeChallengeMethod = authrequestDescCodeChallengeMethod.Default.(string) + // authrequestDescLoginHint is the schema descriptor for login_hint field. + authrequestDescLoginHint := authrequestFields[20].Descriptor() + // authrequest.DefaultLoginHint holds the default value on creation for the login_hint field. + authrequest.DefaultLoginHint = authrequestDescLoginHint.Default.(string) // authrequestDescID is the schema descriptor for id field. authrequestDescID := authrequestFields[0].Descriptor() // authrequest.IDValidator is a validator for the "id" field. It is called by the builders before save. diff --git a/storage/ent/schema/authrequest.go b/storage/ent/schema/authrequest.go index a16fe55180..7af677c8cc 100644 --- a/storage/ent/schema/authrequest.go +++ b/storage/ent/schema/authrequest.go @@ -85,6 +85,9 @@ func (AuthRequest) Fields() []ent.Field { field.Text("code_challenge_method"). SchemaType(textSchema). Default(""), + field.Text("login_hint"). + SchemaType(textSchema). + Default(""), } } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 5a234f9deb..009c2f914a 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -131,10 +131,11 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry, - code_challenge, code_challenge_method + code_challenge, code_challenge_method, + login_hint ) values ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21 ); `, a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, @@ -144,6 +145,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { a.ConnectorID, a.ConnectorData, a.Expiry, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, + a.LoginHint, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -175,8 +177,9 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) claims_groups = $14, connector_id = $15, connector_data = $16, expiry = $17, - code_challenge = $18, code_challenge_method = $19 - where id = $20; + code_challenge = $18, code_challenge_method = $19, + login_hint = $20 + where id = $21; `, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ForceApprovalPrompt, a.LoggedIn, @@ -186,6 +189,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) a.ConnectorID, a.ConnectorData, a.Expiry, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, + a.LoginHint, r.ID, ) if err != nil { @@ -207,7 +211,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry, - code_challenge, code_challenge_method + code_challenge, code_challenge_method, login_hint from auth_request where id = $1; `, id).Scan( &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, @@ -216,7 +220,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry, - &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, + &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, &a.LoginHint, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 498db25276..db48a68f9e 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -281,4 +281,11 @@ var migrations = []migration{ add column obsolete_token text default '';`, }, }, + { + stmts: []string{ + ` + alter table auth_request + add column login_hint text not null default '';`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index cdd83ca6ea..4de9f1f55a 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -199,6 +199,7 @@ type AuthRequest struct { RedirectURI string Nonce string State string + LoginHint string // The client has indicated that the end user must be shown an approval prompt // on all requests. The server cannot cache their initial action for subsequent From 1efa47174919f872be7abfcf90f51d2ec85db20a Mon Sep 17 00:00:00 2001 From: Armaan Varadaraj Date: Tue, 12 Oct 2021 10:21:38 -0700 Subject: [PATCH 04/12] add offline_access scope to oidc connector --- connector/oidc/oidc.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index ec00814b19..1a0e5cb38b 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -130,6 +130,8 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e scopes = append(scopes, "profile", "email") } + scopes = append(scopes, "offline_access") + // PromptType should be "consent" by default, if not set if c.PromptType == "" { c.PromptType = "consent" From f73bd86bbf4c6ffda696476a825aa15ce14438dd Mon Sep 17 00:00:00 2001 From: Armaan Varadaraj Date: Tue, 12 Oct 2021 16:00:25 -0700 Subject: [PATCH 05/12] add some logging --- connector/oidc/oidc.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 1a0e5cb38b..80d74ea2b6 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -273,6 +273,7 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit } token, err := c.oauth2Config.TokenSource(ctx, t).Token() if err != nil { + c.logger.Error("sent refresh token %s", t.RefreshToken) return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err) } @@ -383,6 +384,8 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I } } + c.logger.Error("got id token %s", token.Extra("id_token").(string)) + c.logger.Error("got refresh token %s", token.RefreshToken) cd := connectorData{ RefreshToken: []byte(token.RefreshToken), } From 2060aa2debe90c72d38db0dd5f01f076ca47304c Mon Sep 17 00:00:00 2001 From: Armaan Varadaraj Date: Tue, 12 Oct 2021 18:20:46 -0700 Subject: [PATCH 06/12] dont refresh with connector --- server/refreshhandlers.go | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 8ea7ea9ef1..eb1fb72ac2 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -170,26 +170,6 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre ConnectorData: connectorData, } - // user's token was previously updated by a connector and is allowed to reuse - // it is excessive to refresh identity in upstream - if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken { - return ident, nil - } - - // Can the connector refresh the identity? If so, attempt to refresh the data - // in the connector. - // - // TODO(ericchiang): We may want a strict mode where connectors that don't implement - // this interface can't perform refreshing. - if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { - newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) - if err != nil { - s.logger.Errorf("failed to refresh identity: %v", err) - return connector.Identity{}, newInternalServerError() - } - ident = newIdent - } - return ident, nil } From 825fea71e695c4cb2d924920c5a114151f7177a4 Mon Sep 17 00:00:00 2001 From: Armaan Varadaraj Date: Tue, 12 Oct 2021 18:27:34 -0700 Subject: [PATCH 07/12] lint --- server/refreshhandlers.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index eb1fb72ac2..7fef35975f 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -154,12 +154,6 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre connectorData = session.ConnectorData } - conn, err := s.getConnector(refresh.ConnectorID) - if err != nil { - s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) - return connector.Identity{}, newInternalServerError() - } - ident := connector.Identity{ UserID: refresh.Claims.UserID, Username: refresh.Claims.Username, From 31634c5ecc5fff17c620a6f7dce9b8dde2ecb14f Mon Sep 17 00:00:00 2001 From: Armaan Varadaraj Date: Mon, 18 Oct 2021 09:32:44 -0700 Subject: [PATCH 08/12] remove offline acess type --- connector/oidc/oidc.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 80d74ea2b6..86364c350c 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -213,10 +213,6 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string, opts = append(opts, oauth2.SetAuthURLParam("hd", preferredDomain)) } - if s.OfflineAccess { - opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType)) - } - for p := range forwardParams { paramApproved := false for _, approvedParam := range c.forwardedLoginParams { From 3f72bb841e639abc7313d9494bbe5b0879078c8d Mon Sep 17 00:00:00 2001 From: Armaan Varadaraj Date: Mon, 18 Oct 2021 10:09:18 -0700 Subject: [PATCH 09/12] Revert "lint" This reverts commit 825fea71e695c4cb2d924920c5a114151f7177a4. --- server/refreshhandlers.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 7fef35975f..eb1fb72ac2 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -154,6 +154,12 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre connectorData = session.ConnectorData } + conn, err := s.getConnector(refresh.ConnectorID) + if err != nil { + s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) + return connector.Identity{}, newInternalServerError() + } + ident := connector.Identity{ UserID: refresh.Claims.UserID, Username: refresh.Claims.Username, From 6981cae1f7f10b9dcf205fe34bccd9e474ca0540 Mon Sep 17 00:00:00 2001 From: Armaan Varadaraj Date: Mon, 18 Oct 2021 10:09:26 -0700 Subject: [PATCH 10/12] Revert "dont refresh with connector" This reverts commit 2060aa2debe90c72d38db0dd5f01f076ca47304c. --- server/refreshhandlers.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index eb1fb72ac2..8ea7ea9ef1 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -170,6 +170,26 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre ConnectorData: connectorData, } + // user's token was previously updated by a connector and is allowed to reuse + // it is excessive to refresh identity in upstream + if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken { + return ident, nil + } + + // Can the connector refresh the identity? If so, attempt to refresh the data + // in the connector. + // + // TODO(ericchiang): We may want a strict mode where connectors that don't implement + // this interface can't perform refreshing. + if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { + newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) + if err != nil { + s.logger.Errorf("failed to refresh identity: %v", err) + return connector.Identity{}, newInternalServerError() + } + ident = newIdent + } + return ident, nil } From 37923ca5568e1c994eb234ca97865e4b85327d76 Mon Sep 17 00:00:00 2001 From: Armaan Varadaraj Date: Wed, 20 Oct 2021 11:53:47 -0700 Subject: [PATCH 11/12] remove token logs --- connector/oidc/oidc.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 86364c350c..af372b08ff 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -269,7 +269,6 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit } token, err := c.oauth2Config.TokenSource(ctx, t).Token() if err != nil { - c.logger.Error("sent refresh token %s", t.RefreshToken) return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err) } @@ -380,8 +379,6 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I } } - c.logger.Error("got id token %s", token.Extra("id_token").(string)) - c.logger.Error("got refresh token %s", token.RefreshToken) cd := connectorData{ RefreshToken: []byte(token.RefreshToken), } From d7d7e5baea44966bc336b19e7d86672ae8a4aa97 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 19 Jul 2022 15:27:57 +0000 Subject: [PATCH 12/12] build(deps): bump docker/setup-buildx-action from 1 to 2 Bumps [docker/setup-buildx-action](https://github.com/docker/setup-buildx-action) from 1 to 2. - [Release notes](https://github.com/docker/setup-buildx-action/releases) - [Commits](https://github.com/docker/setup-buildx-action/compare/v1...v2) --- updated-dependencies: - dependency-name: docker/setup-buildx-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/docker.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 26c7a334de..58eaf111d2 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -15,7 +15,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Calculate Docker image tags id: tags @@ -49,7 +49,7 @@ jobs: platforms: all - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v2 with: install: true version: latest @@ -57,7 +57,7 @@ jobs: driver-opts: image=moby/buildkit:master - name: Login to GitHub Container Registry - uses: docker/login-action@v1 + uses: docker/login-action@v2 with: registry: ghcr.io username: ${{ github.repository_owner }} @@ -65,14 +65,14 @@ jobs: if: github.event_name == 'push' - name: Login to Docker Hub - uses: docker/login-action@v1 + uses: docker/login-action@v2 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} if: github.event_name == 'push' - name: Build and push - uses: docker/build-push-action@v2 + uses: docker/build-push-action@v3 with: context: . platforms: linux/amd64,linux/arm/v7,linux/arm64