Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ const (
DelegatedClientSecretFlag = "delegated-client-secret"
DelegatedIssuerFlag = "delegated-issuer"
BaseUrlFlag = "base-url"
AuthIssuersFlag = "auth-issuers"
ListenFlag = "listen"
SigningKeyFlag = "signing-key"
ConfigFlag = "config"
Expand Down Expand Up @@ -108,6 +109,7 @@ func newServeCommand() *cobra.Command {
cmd.Flags().String(DelegatedClientIDFlag, "", "Delegated OIDC client id")
cmd.Flags().String(DelegatedClientSecretFlag, "", "Delegated OIDC client secret")
cmd.Flags().String(BaseUrlFlag, "http://localhost:8080", "Base service url")
cmd.Flags().StringSlice(AuthIssuersFlag, []string{}, "Additional trusted issuer URLs for multi-domain support")
cmd.Flags().String(SigningKeyFlag, defaultSigningKey, "Signing key")
cmd.Flags().String(ListenFlag, ":8080", "Listening address")
cmd.Flags().String(ConfigFlag, "", "Config file name without extension")
Expand All @@ -130,6 +132,9 @@ func runServe(cmd *cobra.Command, _ []string) error {
return errors.New("base url must be defined")
}

additionalIssuers, _ := cmd.Flags().GetStringSlice(AuthIssuersFlag)
trustedIssuers := append([]string{baseUrl}, additionalIssuers...)

signingKey, _ := cmd.Flags().GetString(SigningKeyFlag)
if signingKey == "" {
return errors.New("signing key must be defined")
Expand Down Expand Up @@ -182,10 +187,11 @@ func runServe(cmd *cobra.Command, _ []string) error {
otlpHttpClientModule(service.IsDebug(cmd)),
fx.Supply(fx.Annotate(cmd.Context(), fx.As(new(context.Context)))),
sqlstorage.Module(*connectionOptions, key, service.IsDebug(cmd), o.Clients...),
oidc.Module(key, baseUrl, o.Clients...),
oidc.Module(key, baseUrl, trustedIssuers, o.Clients...),
api.Module(
listen,
baseUrl,
trustedIssuers,
sharedapi.ServiceInfo{
Version: Version,
Debug: service.IsDebug(cmd),
Expand Down
15 changes: 12 additions & 3 deletions pkg/api/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ import (
"github.com/formancehq/go-libs/v3/health"
"github.com/formancehq/go-libs/v3/httpserver"
"github.com/formancehq/go-libs/v3/logging"
authoidc "github.com/formancehq/auth/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"go.uber.org/fx"
)

func CreateRootRouter(
logger logging.Logger,
issuer string,
defaultIssuer string,
trustedIssuers []string,
debug bool,
) chi.Router {
rootRouter := chi.NewRouter()
Expand All @@ -31,6 +33,13 @@ func CreateRootRouter(
})
rootRouter.Use(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host := authoidc.HostFromRequest(r)
issuer := authoidc.IssuerForHost(host, defaultIssuer, trustedIssuers)
// Rewrite r.Host so ZITADEL's dynamic IssuerFromRequest reads
// the correct host for discovery and other OIDC endpoints.
if h := authoidc.HostFromIssuer(issuer); h != "" {
r.Host = h
}
handler.ServeHTTP(w, r.WithContext(
op.ContextWithIssuer(r.Context(), issuer),
))
Expand All @@ -43,12 +52,12 @@ func addInfoRoute(router chi.Router, serviceInfo api.ServiceInfo) {
router.Get("/_info", api.InfoHandler(serviceInfo))
}

func Module(addr, issuer string, serviceInfo api.ServiceInfo, debug bool) fx.Option {
func Module(addr, defaultIssuer string, trustedIssuers []string, serviceInfo api.ServiceInfo, debug bool) fx.Option {
return fx.Options(
health.Module(),
fx.Supply(serviceInfo),
fx.Provide(func(logger logging.Logger) chi.Router {
return CreateRootRouter(logger, issuer, debug)
return CreateRootRouter(logger, defaultIssuer, trustedIssuers, debug)
}),
fx.Invoke(
addInfoRoute,
Expand Down
14 changes: 9 additions & 5 deletions pkg/oidc/grant_type_bearer.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@ type JWTProfileVerifier interface {

type JWTAuthorizationGrantExchanger interface {
op.Exchanger
JWTProfileVerifier() JWTProfileVerifier
JWTProfileVerifier(issuer string) JWTProfileVerifier
}

func grantTypeBearer(issuer string, p JWTAuthorizationGrantExchanger) http.HandlerFunc {
func grantTypeBearer(p JWTAuthorizationGrantExchanger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
issuer := op.IssuerFromContext(r.Context())

profileRequest, err := op.ParseJWTProfileGrantRequest(r, p.Decoder())
if err != nil {
op.RequestError(w, r, err)
Expand All @@ -95,7 +97,7 @@ func grantTypeBearer(issuer string, p JWTAuthorizationGrantExchanger) http.Handl
}
}

tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, p.JWTProfileVerifier())
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, p.JWTProfileVerifier(issuer))
if err != nil {
op.RequestError(w, r, err)
return
Expand All @@ -115,7 +117,7 @@ func grantTypeBearer(issuer string, p JWTAuthorizationGrantExchanger) http.Handl

tokenRequest.Scopes = tokens.Scopes

resp, err := CreateJWTTokenResponse(r.Context(), issuer, tokenRequest, p, client)
resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, p, client)
if err != nil {
op.RequestError(w, r, err)
return
Expand All @@ -136,7 +138,9 @@ func ParseAssertion(assertion string) (*oidc.AccessTokenClaims, error) {
return claims, nil
}

func CreateJWTTokenResponse(ctx context.Context, issuer string, tokenRequest *oidc.JWTTokenRequest, creator op.TokenCreator, client op.Client) (*oidc.AccessTokenResponse, error) {
func CreateJWTTokenResponse(ctx context.Context, tokenRequest *oidc.JWTTokenRequest, creator op.TokenCreator, client op.Client) (*oidc.AccessTokenResponse, error) {
issuer := op.IssuerFromContext(ctx)

id, exp, err := creator.Storage().CreateAccessToken(ctx, tokenRequest)
if err != nil {
return nil, err
Expand Down
34 changes: 34 additions & 0 deletions pkg/oidc/issuer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package oidc

import (
"net/http"
"net/url"
)

// HostFromRequest returns the effective host, preferring X-Forwarded-Host over Host.
func HostFromRequest(r *http.Request) string {
if fwd := r.Header.Get("X-Forwarded-Host"); fwd != "" {
return fwd
}
return r.Host
}

// IssuerForHost returns the trusted issuer matching the given host, or the default.
func IssuerForHost(host, defaultIssuer string, trustedIssuers []string) string {
for _, issuer := range trustedIssuers {
u, err := url.Parse(issuer)
if err == nil && u.Host == host {
return issuer
}
}
return defaultIssuer
}

// HostFromIssuer extracts the host component from an issuer URL.
func HostFromIssuer(issuer string) string {
u, err := url.Parse(issuer)
if err != nil {
return ""
}
return u.Host
}
4 changes: 2 additions & 2 deletions pkg/oidc/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"go.uber.org/fx"
)

func Module(privateKey *rsa.PrivateKey, issuer string, staticClients ...auth.StaticClient) fx.Option {
func Module(privateKey *rsa.PrivateKey, issuer string, trustedIssuers []string, staticClients ...auth.StaticClient) fx.Option {
return fx.Options(
fx.Invoke(fx.Annotate(func(router chi.Router, provider op.OpenIDProvider,
storage Storage, relyingParty rp.RelyingParty) {
Expand All @@ -37,7 +37,7 @@ func Module(privateKey *rsa.PrivateKey, issuer string, staticClients ...auth.Sta
}
}

return NewOpenIDProvider(storage, issuer, configuration.Issuer, keySet)
return NewOpenIDProvider(storage, issuer, trustedIssuers, configuration.Issuer, keySet)
}, fx.ParamTags(``, ``, `optional:"true"`))),
)
}
2 changes: 1 addition & 1 deletion pkg/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func withServer(t *testing.T, fn func(m *mockoidc.MockOIDC, storage *sqlstorage.
require.NoError(t, err)

// Construct our oidc provider
provider, err := oidc.NewOpenIDProvider(storageFacade, serverUrl, mockOIDC.Issuer(), keySet)
provider, err := oidc.NewOpenIDProvider(storageFacade, serverUrl, []string{serverUrl}, mockOIDC.Issuer(), keySet)
require.NoError(t, err)

// Create the router
Expand Down
28 changes: 20 additions & 8 deletions pkg/oidc/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc
import (
"crypto/sha256"
"net/http"
"net/url"
"time"

"github.com/zitadel/oidc/v2/pkg/oidc"
Expand Down Expand Up @@ -47,12 +48,12 @@ type provider struct {
op.OpenIDProvider
delegatedIssuerJsonWebKeySet jose.JSONWebKeySet
delegatedIssuer string
issuer string
trustedIssuers []string
}

func (p provider) JWTProfileVerifier() JWTProfileVerifier {
func (p provider) JWTProfileVerifier(issuer string) JWTProfileVerifier {
return &verifier{
issuer: p.issuer,
issuer: issuer,
delegatedIssuer: p.delegatedIssuer,
mat: time.Hour,
offset: 0,
Expand All @@ -62,9 +63,14 @@ func (p provider) JWTProfileVerifier() JWTProfileVerifier {

var _ JWTAuthorizationGrantExchanger = (*provider)(nil)

func NewOpenIDProvider(storage op.Storage, issuer, delegatedIssuer string, delegatedIssuerJsonWebKeySet *jose.JSONWebKeySet) (op.OpenIDProvider, error) {
func NewOpenIDProvider(storage op.Storage, issuer string, trustedIssuers []string, delegatedIssuer string, delegatedIssuerJsonWebKeySet *jose.JSONWebKeySet) (op.OpenIDProvider, error) {
var p op.OpenIDProvider

parsedIssuer, err := url.Parse(issuer)
if err != nil {
return nil, err
}

interceptors := make([]op.Option, 0)
if delegatedIssuer != "" {
interceptors = append(interceptors, op.WithHttpInterceptors(func(handler http.Handler) http.Handler {
Expand All @@ -73,8 +79,8 @@ func NewOpenIDProvider(storage op.Storage, issuer, delegatedIssuer string, deleg
// as the library does not implement what we needs
if r.URL.Path == op.DefaultEndpoints.Token.Relative() &&
r.FormValue("grant_type") == string(oidc.GrantTypeBearer) {
grantTypeBearer(issuer, &provider{
issuer: issuer,
grantTypeBearer(&provider{
trustedIssuers: trustedIssuers,
OpenIDProvider: p,
delegatedIssuerJsonWebKeySet: *delegatedIssuerJsonWebKeySet,
delegatedIssuer: delegatedIssuer,
Expand All @@ -86,9 +92,15 @@ func NewOpenIDProvider(storage op.Storage, issuer, delegatedIssuer string, deleg

}))
}
interceptors = append(interceptors, op.WithAllowInsecure())

p, err := op.NewOpenIDProvider(issuer, &op.Config{
if parsedIssuer.Scheme == "http" {
interceptors = append(interceptors, op.WithAllowInsecure())
}

// Use NewDynamicOpenIDProvider so ZITADEL reads the issuer from r.Host
// (which is set by the chi middleware based on trusted issuers) instead
// of using a static issuer string.
p, err = op.NewDynamicOpenIDProvider(parsedIssuer.Path, &op.Config{
CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true,
Expand Down
Loading