diff --git a/opentdf-core-mode.yaml b/opentdf-core-mode.yaml index 98401d19ee..4350cbb70a 100644 --- a/opentdf-core-mode.yaml +++ b/opentdf-core-mode.yaml @@ -11,6 +11,14 @@ logger: level: debug type: text output: stderr +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles # DB and Server configurations are defaulted for local development # db: # host: localhost @@ -57,4 +65,4 @@ server: maxage: 3600 grpc: reflectionEnabled: true # Default is false - port: 8383 \ No newline at end of file + port: 8383 diff --git a/opentdf-dev.yaml b/opentdf-dev.yaml index 93a471101f..dd2050e30f 100644 --- a/opentdf-dev.yaml +++ b/opentdf-dev.yaml @@ -2,6 +2,14 @@ logger: level: debug type: text output: stderr +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles # DB and Server configurations are defaulted for local development # db: # host: localhost diff --git a/opentdf-ers-mode.yaml b/opentdf-ers-mode.yaml index 220317039c..95186c759b 100644 --- a/opentdf-ers-mode.yaml +++ b/opentdf-ers-mode.yaml @@ -5,6 +5,14 @@ logger: level: debug type: text output: stderr +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles services: entityresolution: log_level: info diff --git a/opentdf-ers-test.yaml b/opentdf-ers-test.yaml index 38d4833592..e46eb43f06 100644 --- a/opentdf-ers-test.yaml +++ b/opentdf-ers-test.yaml @@ -12,6 +12,14 @@ mode: standalone logger: level: info type: json +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles crypto: type: standard @@ -256,4 +264,4 @@ development: # Allow insecure connections for testing allow_insecure: true # Disable some validations for easier testing - strict_validation: false \ No newline at end of file + strict_validation: false diff --git a/opentdf-example.yaml b/opentdf-example.yaml index a0b97da826..a397f20ff3 100644 --- a/opentdf-example.yaml +++ b/opentdf-example.yaml @@ -2,6 +2,14 @@ logger: level: debug type: text output: stdout +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles # DB and Server configurations are defaulted for local development db: host: opentdfdb diff --git a/opentdf-kas-mode.yaml b/opentdf-kas-mode.yaml index ebcfb6f0c2..8359710fdf 100644 --- a/opentdf-kas-mode.yaml +++ b/opentdf-kas-mode.yaml @@ -11,6 +11,14 @@ logger: level: debug type: text output: stderr +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles security: unsafe: # Increase only when diagnosing clock drift issues; default is 1m diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index b5c68ba837..449e252689 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -438,10 +438,52 @@ func IPCMetadataClientInterceptor(log *logger.Logger) connect.UnaryInterceptorFu }) } +// rehydrateIPCAuthContext reconstructs pkg/auth context from propagated IPC metadata. +// It only transports an already-authenticated token across an internal hop; it does not +// validate the token again. Network-facing requests must continue to rely on the +// ConnectUnaryServerInterceptor auth middleware for validation. +func rehydrateIPCAuthContext(ctx context.Context, l *logger.Logger) (context.Context, error) { + if ctxAuth.GetAccessTokenFromContext(ctx, l) != nil && ctxAuth.GetRawAccessTokenFromContext(ctx, l) != "" { + return ctx, nil + } + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx, nil + } + + rawToken := rawAccessTokenFromIncomingMetadata(md) + if rawToken == "" { + return ctx, nil + } + + parsed, err := jwt.Parse([]byte(rawToken), jwt.WithVerify(false), jwt.WithValidate(false)) + if err != nil { + if l != nil { + l.ErrorContext(ctx, "failed to rehydrate IPC access token from metadata", slog.Any("error", err)) + } + return ctx, fmt.Errorf("rehydrate IPC access token from metadata: %w", err) + } + + return ctxAuth.ContextWithAuthNInfo(ctx, ctxAuth.GetJWKFromContext(ctx, l), parsed, rawToken), nil +} + +func rawAccessTokenFromIncomingMetadata(md metadata.MD) string { + if accessTokens := md.Get(ctxAuth.AccessTokenKey); len(accessTokens) > 0 && accessTokens[0] != "" { + return accessTokens[0] + } + + authHeaders := md.Get("authorization") + if len(authHeaders) == 0 || authHeaders[0] == "" { + return "" + } + return strings.TrimPrefix(strings.TrimPrefix(authHeaders[0], "Bearer "), "DPoP ") +} + // IPCUnaryServerInterceptor is a grpc interceptor that: -// 1. verifies the token in the metadata -// 2. reauthorizes the token if the route is in the list -// 3. translates known IPC Connect request headers back to context metadata for downstream consumers +// 1. translates known IPC Connect request headers back to incoming metadata +// 2. reauthorizes routes that are configured for IPC reauth +// 3. rehydrates auth context from propagated incoming metadata without revalidating it func (a Authentication) IPCUnaryServerInterceptor() connect.UnaryInterceptorFunc { interceptor := func(next connect.UnaryFunc) connect.UnaryFunc { return connect.UnaryFunc(func( @@ -467,6 +509,10 @@ func (a Authentication) IPCUnaryServerInterceptor() connect.UnaryInterceptorFunc if err != nil { return nil, err } + nextCtx, err = rehydrateIPCAuthContext(nextCtx, a.logger) + if err != nil { + return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("invalid IPC authentication context")) + } return next(nextCtx, req) }) } diff --git a/service/internal/auth/authn_ipc_metadata_interceptor_test.go b/service/internal/auth/authn_ipc_metadata_interceptor_test.go index 16d15c3e26..2bd5d8b486 100644 --- a/service/internal/auth/authn_ipc_metadata_interceptor_test.go +++ b/service/internal/auth/authn_ipc_metadata_interceptor_test.go @@ -2,6 +2,7 @@ package auth import ( "context" + "errors" "testing" "connectrpc.com/connect" @@ -184,6 +185,10 @@ func (m *mockAnyRequest) Any() any { func TestIPCUnaryServerInterceptor(t *testing.T) { testLogger := logger.CreateTestLogger() + validJWT, err := jwt.NewBuilder().Subject("ipc-user").Build() + require.NoError(t, err) + validRawToken, err := jwt.Sign(validJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) // Create a minimal authentication instance auth := &Authentication{ @@ -201,7 +206,7 @@ func TestIPCUnaryServerInterceptor(t *testing.T) { setupRequest: func() connect.AnyRequest { req := connect.NewRequest(&kas.PublicKeyRequest{}) req.Header().Set(canonicalIPCHeaderClientID, "test-client-from-header") - req.Header().Set(canonicalIPCHeaderAccessToken, "test-token-from-header") + req.Header().Set(canonicalIPCHeaderAccessToken, string(validRawToken)) return &mockAnyRequest{ Request: req, isClient: false, @@ -225,7 +230,7 @@ func TestIPCUnaryServerInterceptor(t *testing.T) { setupRequest: func() connect.AnyRequest { req := connect.NewRequest(&kas.PublicKeyRequest{}) req.Header().Set(canonicalIPCHeaderClientID, "merged-client-id") - req.Header().Set(canonicalIPCHeaderAccessToken, "merged-token") + req.Header().Set(canonicalIPCHeaderAccessToken, string(validRawToken)) return &mockAnyRequest{ Request: req, isClient: false, @@ -251,8 +256,13 @@ func TestIPCUnaryServerInterceptor(t *testing.T) { for _, key := range tt.expectedIncomingMDKeys { assert.NotEmpty(t, md.Get(key), "metadata key %s should exist", key) } + retrievedJWT := ctxAuth.GetAccessTokenFromContext(postInterceptorCtx, testLogger) + require.NotNil(t, retrievedJWT) + assert.Equal(t, "ipc-user", retrievedJWT.Subject()) + assert.Equal(t, string(validRawToken), ctxAuth.GetRawAccessTokenFromContext(postInterceptorCtx, testLogger)) } else { assert.Zero(t, md.Len()) + assert.Nil(t, ctxAuth.GetAccessTokenFromContext(postInterceptorCtx, testLogger)) } return connect.NewResponse(&kas.PublicKeyResponse{}), nil } @@ -267,19 +277,22 @@ func TestIPCUnaryServerInterceptor(t *testing.T) { func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { testLogger := logger.CreateTestLogger() + mockJWT, err := jwt.NewBuilder().Subject("integration-user").Build() + require.NoError(t, err) + rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) auth := &Authentication{ logger: testLogger, ipcReauthRoutes: []string{}, } - t.Run("clientID and access token from headers available in context metadata", func(t *testing.T) { + t.Run("clientID and access token from headers available in context metadata and auth context", func(t *testing.T) { clientID := "integration-client-id" - accessToken := "integration-access-token" req := connect.NewRequest(&kas.PublicKeyRequest{}) req.Header().Set(canonicalIPCHeaderClientID, clientID) - req.Header().Set(canonicalIPCHeaderAccessToken, accessToken) + req.Header().Set(canonicalIPCHeaderAccessToken, string(rawToken)) wrappedReq := &mockAnyRequest{ Request: req, @@ -290,7 +303,7 @@ func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { ctx := t.Context() - var receivedClientID, receivedAccessToken string + var receivedClientID, receivedAccessToken, receivedSubject string mockNext := func(ctx context.Context, _ connect.AnyRequest) (connect.AnyResponse, error) { md, ok := metadata.FromIncomingContext(ctx) require.True(t, ok) @@ -302,6 +315,9 @@ func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { if len(accessTokens) > 0 { receivedAccessToken = accessTokens[0] } + retrievedJWT := ctxAuth.GetAccessTokenFromContext(ctx, testLogger) + require.NotNil(t, retrievedJWT) + receivedSubject = retrievedJWT.Subject() return connect.NewResponse(&kas.PublicKeyResponse{}), nil } @@ -310,6 +326,52 @@ func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { require.NoError(t, err) assert.Equal(t, clientID, receivedClientID) - assert.Equal(t, accessToken, receivedAccessToken) + assert.Equal(t, string(rawToken), receivedAccessToken) + assert.Equal(t, "integration-user", receivedSubject) + }) + + t.Run("invalid access token header fails rehydration", func(t *testing.T) { + req := connect.NewRequest(&kas.PublicKeyRequest{}) + req.Header().Set(canonicalIPCHeaderAccessToken, "not-a-jwt") + + wrappedReq := &mockAnyRequest{ + Request: req, + isClient: false, + } + + interceptor := auth.IPCUnaryServerInterceptor() + interceptorFunc := interceptor(func(context.Context, connect.AnyRequest) (connect.AnyResponse, error) { + t.Fatal("next handler should not be called on invalid token") + return nil, errors.New("unreachable") + }) + + _, err := interceptorFunc(t.Context(), wrappedReq) + require.Error(t, err) + + var connectErr *connect.Error + require.ErrorAs(t, err, &connectErr) + assert.Equal(t, connect.CodeUnauthenticated, connectErr.Code()) }) } + +func TestRehydrateIPCAuthContextPreservesExistingJWK(t *testing.T) { + testLogger := logger.CreateTestLogger() + mockJWT, err := jwt.NewBuilder().Subject("rehydrated-user").Build() + require.NoError(t, err) + rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) + mockJWK, err := jwk.FromRaw([]byte("existing-jwk")) + require.NoError(t, err) + + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(ctxAuth.AccessTokenKey, string(rawToken))) + ctx = ctxAuth.ContextWithAuthNInfo(ctx, mockJWK, nil, "") + + rehydratedCtx, err := rehydrateIPCAuthContext(ctx, testLogger) + require.NoError(t, err) + require.Same(t, mockJWK, ctxAuth.GetJWKFromContext(rehydratedCtx, testLogger)) + + retrievedJWT := ctxAuth.GetAccessTokenFromContext(rehydratedCtx, testLogger) + require.NotNil(t, retrievedJWT) + assert.Equal(t, "rehydrated-user", retrievedJWT.Subject()) + assert.Equal(t, string(rawToken), ctxAuth.GetRawAccessTokenFromContext(rehydratedCtx, testLogger)) +} diff --git a/service/internal/dotnotation/dotnotation.go b/service/internal/dotnotation/dotnotation.go new file mode 100644 index 0000000000..0992f60dd6 --- /dev/null +++ b/service/internal/dotnotation/dotnotation.go @@ -0,0 +1,86 @@ +package dotnotation + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +// Get retrieves a value from a nested map using dot notation keys. +func Get(m map[string]any, key string) any { + keys := strings.Split(key, ".") + for i, k := range keys { + if i == len(keys)-1 { + return m[k] + } + if m[k] == nil { + return nil + } + var ok bool + m, ok = toMap(m[k]) + if !ok { + return nil + } + } + return nil +} + +// Set stores a value in a nested map using dot notation keys, creating +// intermediate maps as needed. +func Set(m map[string]any, key string, value any) error { + if m == nil { + return errors.New("nil root map") + } + if key == "" { + return errors.New("empty path") + } + + keys := strings.Split(key, ".") + for _, segment := range keys { + if segment == "" { + return fmt.Errorf("invalid path %q: empty segment", key) + } + } + current := m + for i, k := range keys[:len(keys)-1] { + next, exists := current[k] + if !exists || next == nil { + child := map[string]any{} + current[k] = child + current = child + continue + } + + child, ok := toMap(next) + if !ok { + return fmt.Errorf("path collision at %s", strings.Join(keys[:i+1], ".")) + } + current[k] = child + current = child + } + + current[keys[len(keys)-1]] = value + return nil +} + +func toMap(value any) (map[string]any, bool) { + if value == nil { + return nil, false + } + if typed, ok := value.(map[string]any); ok { + return typed, true + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Map || rv.Type().Key().Kind() != reflect.String { + return nil, false + } + + out := make(map[string]any, rv.Len()) + iter := rv.MapRange() + for iter.Next() { + out[iter.Key().String()] = iter.Value().Interface() + } + return out, true +} diff --git a/service/internal/dotnotation/dotnotation_test.go b/service/internal/dotnotation/dotnotation_test.go new file mode 100644 index 0000000000..9a9ba5cd59 --- /dev/null +++ b/service/internal/dotnotation/dotnotation_test.go @@ -0,0 +1,112 @@ +package dotnotation + +import "testing" + +func TestGet(t *testing.T) { + tests := []struct { + name string + input map[string]any + key string + expected any + }{ + {name: "valid key", input: map[string]any{"a": map[string]any{"b": 1}}, key: "a.b", expected: 1}, + {name: "non-existent key", input: map[string]any{"a": map[string]any{"b": 1}}, key: "a.c", expected: nil}, + {name: "nested map", input: map[string]any{"a": map[string]any{"b": map[string]any{"c": 2}}}, key: "a.b.c", expected: 2}, + {name: "map string string", input: map[string]any{"a": map[string]string{"b": "value"}}, key: "a.b", expected: "value"}, + {name: "invalid key type", input: map[string]any{"a": 1}, key: "a.b", expected: nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Get(tt.input, tt.key) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestSet(t *testing.T) { + t.Run("creates nested maps", func(t *testing.T) { + input := map[string]any{} + + err := Set(input, "eventMetaData.requester.sub", "test-user") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := Get(input, "eventMetaData.requester.sub"); got != "test-user" { + t.Fatalf("expected nested value, got %v", got) + } + }) + + t.Run("converts string keyed maps", func(t *testing.T) { + input := map[string]any{ + "eventMetaData": map[string]string{ + "existing": "value", + }, + } + + err := Set(input, "eventMetaData.requester.sub", "test-user") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := Get(input, "eventMetaData.existing"); got != "value" { + t.Fatalf("expected existing value, got %v", got) + } + if got := Get(input, "eventMetaData.requester.sub"); got != "test-user" { + t.Fatalf("expected nested value, got %v", got) + } + }) + + t.Run("fails on non-map collision", func(t *testing.T) { + input := map[string]any{ + "eventMetaData": "not-a-map", + } + + err := Set(input, "eventMetaData.requester.sub", "test-user") + if err == nil { + t.Fatal("expected collision error") + } + }) + + t.Run("fails on nil root map", func(t *testing.T) { + defer func() { + if recovered := recover(); recovered != nil { + t.Fatalf("Set should not panic, got %v", recovered) + } + }() + + if err := Set(nil, "a.b", "value"); err == nil { + t.Fatal("expected error for nil root map") + } + }) + + t.Run("fails on malformed paths", func(t *testing.T) { + for _, path := range []string{"a..b", ".a", "a."} { + t.Run(path, func(t *testing.T) { + defer func() { + if recovered := recover(); recovered != nil { + t.Fatalf("Set should not panic, got %v", recovered) + } + }() + + input := make(map[string]any) + if err := Set(input, path, "value"); err == nil { + t.Fatal("expected malformed path error") + } + + if got := Get(input, "a"); got != nil { + t.Fatalf("expected no root value for a, got %v", got) + } + if got := Get(input, "a."); got != nil { + t.Fatalf("expected no nested invalid value, got %v", got) + } + if _, exists := input[""]; exists { + t.Fatal("expected no empty-string root key") + } + }) + } + }) +} diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 8d16b784f7..349c911fd2 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -542,7 +542,7 @@ func newConnectRPC(c Config, authInt connect.Interceptor, ints []connect.Interce // Add protovalidate interceptor validationInterceptor := validate.NewInterceptor() - interceptors = append(interceptors, connect.WithInterceptors(validationInterceptor, audit.ContextServerInterceptor(logger.Logger))) + interceptors = append(interceptors, connect.WithInterceptors(validationInterceptor, audit.ContextServerInterceptor(logger.Audit))) // Add any additional interceptors provided programmatically AFTER the default ones, so they have access needed context if len(ints) > 0 { diff --git a/service/logger/audit/config.go b/service/logger/audit/config.go new file mode 100644 index 0000000000..220fa1f971 --- /dev/null +++ b/service/logger/audit/config.go @@ -0,0 +1,73 @@ +package audit + +import ( + "fmt" + "strings" +) + +// JWTClaimMapping maps a JWT claim to a destination in the emitted audit log. +// The destination path uses dot notation over the audit log output shape, e.g. +// `eventMetaData.requester.sub`. +type JWTClaimMapping struct { + Claim string `mapstructure:"claim" json:"claim"` + Path string `mapstructure:"path" json:"path"` +} + +// Config contains platform-wide audit configuration. +type Config struct { + // JWTClaimMappings writes JWT claims to configured destinations in the audit + // log output, preserving the original JSON value types where possible. + JWTClaimMappings []JWTClaimMapping `mapstructure:"jwt_claim_mappings" json:"jwt_claim_mappings"` +} + +func (c Config) Validate() error { + for idx, mapping := range c.JWTClaimMappings { + if mapping.Claim == "" { + return fmt.Errorf("jwt_claim_mappings[%d].claim is required", idx) + } + if mapping.Path == "" { + return fmt.Errorf("jwt_claim_mappings[%d].path is required", idx) + } + if err := validateClaimDestinationPath(mapping.Path); err != nil { + return fmt.Errorf("jwt_claim_mappings[%d].path: %w", idx, err) + } + } + if err := validateNoOverlappingPaths(c.JWTClaimMappings); err != nil { + return err + } + return nil +} + +func validateNoOverlappingPaths(mappings []JWTClaimMapping) error { + for i, a := range mappings { + aParts := strings.Split(a.Path, ".") + for j, b := range mappings { + if i >= j { + continue + } + if a.Path == b.Path { + return fmt.Errorf("%w: duplicate path %q", ErrOverlappingAuditPaths, a.Path) + } + bParts := strings.Split(b.Path, ".") + if isPathPrefix(aParts, bParts) { + return fmt.Errorf("%w: %q is a prefix of %q", ErrOverlappingAuditPaths, a.Path, b.Path) + } + if isPathPrefix(bParts, aParts) { + return fmt.Errorf("%w: %q is a prefix of %q", ErrOverlappingAuditPaths, b.Path, a.Path) + } + } + } + return nil +} + +func isPathPrefix(short, long []string) bool { + if len(short) >= len(long) { + return false + } + for i, s := range short { + if s != long[i] { + return false + } + } + return true +} diff --git a/service/logger/audit/contextServerInterceptor.go b/service/logger/audit/contextServerInterceptor.go index bbf69cf347..13a7032347 100644 --- a/service/logger/audit/contextServerInterceptor.go +++ b/service/logger/audit/contextServerInterceptor.go @@ -2,7 +2,6 @@ package audit import ( "context" - "log/slog" "net/http" "connectrpc.com/connect" @@ -13,7 +12,7 @@ import ( // ContextServerInterceptor allows audit events to track request state. // This is required for audit logging. -func ContextServerInterceptor(logger *slog.Logger) connect.UnaryInterceptorFunc { +func ContextServerInterceptor(logger *Logger) connect.UnaryInterceptorFunc { interceptor := func(next connect.UnaryFunc) connect.UnaryFunc { return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { // Get metadata from the context diff --git a/service/logger/audit/enrichment.go b/service/logger/audit/enrichment.go new file mode 100644 index 0000000000..ee53785050 --- /dev/null +++ b/service/logger/audit/enrichment.go @@ -0,0 +1,153 @@ +package audit + +import ( + "context" + "encoding" + "encoding/json" + "log/slog" + "reflect" + + dotnotation "github.com/opentdf/platform/service/internal/dotnotation" + ctxAuth "github.com/opentdf/platform/service/pkg/auth" +) + +func (a *Logger) buildLogEntry(ctx context.Context, event *EventObject) map[string]any { + entry := event.emittedPayloadMap() + a.applyJWTClaimEnrichment(ctx, entry) + return entry +} + +func (a *Logger) applyJWTClaimEnrichment(ctx context.Context, entry map[string]any) { + if len(a.config.JWTClaimMappings) == 0 { + return + } + + token := ctxAuth.GetAccessTokenFromContext(ctx, a.logger) + if token == nil { + return + } + + claimsMap, err := token.AsMap(ctx) + if err != nil { + a.logger.ErrorContext(ctx, "failed to read JWT claims for audit enrichment", slog.Any("error", err)) + return + } + + a.applyMappedJWTClaims(ctx, entry, claimsMap) +} + +func (a *Logger) applyMappedJWTClaims(ctx context.Context, entry map[string]any, claimsMap map[string]any) { + for _, mapping := range a.config.JWTClaimMappings { + if mapping.Claim == "" || mapping.Path == "" { + continue + } + + value := dotnotation.Get(claimsMap, mapping.Claim) + if value == nil { + continue + } + + if err := dotnotation.Set(entry, mapping.Path, normalizeAuditValue(value)); err != nil { + a.logger.ErrorContext(ctx, + "failed to apply JWT claim mapping to audit log", + slog.String("claim", mapping.Claim), + slog.String("path", mapping.Path), + slog.Any("error", err), + ) + } + } +} + +func normalizeAuditValue(value any) any { + switch typed := value.(type) { + case nil: + return nil + case map[string]any: + if typed == nil { + return nil + } + normalized := make(map[string]any, len(typed)) + for key, nested := range typed { + normalized[key] = normalizeAuditValue(nested) + } + return normalized + case []any: + if typed == nil { + return nil + } + normalized := make([]any, len(typed)) + for idx, nested := range typed { + normalized[idx] = normalizeAuditValue(nested) + } + return normalized + } + + if marshaler, ok := value.(json.Marshaler); ok { + encoded, err := marshaler.MarshalJSON() + if err == nil { + var decoded any + if err := json.Unmarshal(encoded, &decoded); err == nil { + return decoded + } + } + } + + if marshaler, ok := value.(encoding.TextMarshaler); ok { + encoded, err := marshaler.MarshalText() + if err == nil { + return string(encoded) + } + } + + rv := reflect.ValueOf(value) + if !rv.IsValid() { + return nil + } + //nolint:exhaustive // only composite kinds need normalization; scalars can pass through unchanged + switch rv.Kind() { + case reflect.Pointer: + if rv.IsNil() { + return nil + } + return normalizeAuditValue(rv.Elem().Interface()) + case reflect.Struct: + structType := rv.Type() + normalized := make(map[string]any, structType.NumField()) + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + if !field.IsExported() { + continue + } + opts, _ := parseAuditFieldOptions(field) + if opts.name == "" { + continue + } + normalized[opts.name] = normalizeAuditValue(rv.Field(i).Interface()) + } + return normalized + case reflect.Map: + if rv.IsNil() { + return nil + } + if rv.Type().Key().Kind() != reflect.String { + return value + } + normalized := make(map[string]any, rv.Len()) + iter := rv.MapRange() + for iter.Next() { + normalized[iter.Key().String()] = normalizeAuditValue(iter.Value().Interface()) + } + return normalized + case reflect.Slice, reflect.Array: + if rv.Kind() == reflect.Slice && rv.IsNil() { + return nil + } + normalized := make([]any, rv.Len()) + for idx := range normalized { + normalized[idx] = normalizeAuditValue(rv.Index(idx).Interface()) + } + return normalized + default: + return value + } +} diff --git a/service/logger/audit/logger.go b/service/logger/audit/logger.go index 56060c6f56..5ca12509b9 100644 --- a/service/logger/audit/logger.go +++ b/service/logger/audit/logger.go @@ -34,6 +34,7 @@ var logLevelNames = map[slog.Leveler]string{ type Logger struct { logger *slog.Logger + config Config } // Used to support custom log levels showing up with custom labels as well @@ -61,10 +62,26 @@ func CreateAuditLogger(logger slog.Logger) *Logger { } } +func cloneConfig(cfg Config) Config { + cloned := cfg + cloned.JWTClaimMappings = append([]JWTClaimMapping(nil), cfg.JWTClaimMappings...) + return cloned +} + +// ApplyConfig validates and stores the latest audit enrichment configuration. +func (a *Logger) ApplyConfig(cfg Config) error { + if err := cfg.Validate(); err != nil { + return err + } + a.config = cloneConfig(cfg) + return nil +} + func (a *Logger) With(key string, value string) *Logger { return &Logger{ //nolint:sloglint // custom logger should support key/value pairs in With attributes logger: a.logger.With(key, value), + config: cloneConfig(a.config), } } @@ -81,7 +98,7 @@ func (tx *auditTransaction) addEvent(verb Verb, event *EventObject) { // logClose completes an audit transaction and emits all recorded events. // If success is false or err is not nil, events are logged as "cancelled" with the error attached. // Otherwise, events are logged with their originally recorded success/failure status. -func (tx *auditTransaction) logClose(ctx context.Context, logger *slog.Logger, success bool, err error) { +func (tx *auditTransaction) logClose(ctx context.Context, auditLogger *Logger, success bool, err error) { tx.mu.Lock() defer tx.mu.Unlock() for _, event := range tx.events { @@ -99,7 +116,7 @@ func (tx *auditTransaction) logClose(ctx context.Context, logger *slog.Logger, s } //nolint:sloglint // audit message is always just the verb - logger.Log(ctx, LevelAudit, string(event.verb), slog.Any("audit", *auditEvent)) + auditLogger.logger.Log(ctx, LevelAudit, string(event.verb), slog.Any("audit", auditLogger.buildLogEntry(ctx, auditEvent))) } } diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index c254f6dfe4..c2000928c2 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -11,7 +11,9 @@ import ( "time" "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/protocol/go/authorization" + ctxAuth "github.com/opentdf/platform/service/pkg/auth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -91,21 +93,24 @@ func extractLogEntry(t *testing.T, logBuffer *bytes.Buffer) (logEntryStructure, return entry, entryTime } -func doWithLogger(t *testing.T, testFunc func(ctx context.Context, l *Logger)) (ls logEntryStructure, lt time.Time) { //nolint:nonamedreturns // Required to rewrite on panics, right? - l, buf := createTestLogger() +func doWithLogger(t *testing.T, contextSetup func(context.Context) context.Context, testFunc func(ctx context.Context, l *Logger)) (ls logEntryStructure, lt time.Time) { //nolint:nonamedreturns // Named returns let the deferred recover path populate the extracted audit log. ctx := createTestContext(t) + if contextSetup != nil { + ctx = contextSetup(ctx) + } + l, buf := createTestLogger() tx, ok := ctx.Value(contextKey{}).(*auditTransaction) require.True(t, ok, "audit transaction missing from context") defer func() { if r := recover(); r != nil { if err, okerr := r.(error); okerr { - tx.logClose(ctx, l.logger, false, err) + tx.logClose(ctx, l, false, err) } else { - tx.logClose(ctx, l.logger, false, nil) + tx.logClose(ctx, l, false, nil) } } else { - tx.logClose(ctx, l.logger, true, nil) + tx.logClose(ctx, l, true, nil) } ls, lt = extractLogEntry(t, buf) }() @@ -115,8 +120,32 @@ func doWithLogger(t *testing.T, testFunc func(ctx context.Context, l *Logger)) ( return ls, lt } +func createTestJWTForAudit(t *testing.T) (jwt.Token, string) { + t.Helper() + + token, err := jwt.NewBuilder(). + Subject("jwt-user"). + Claim("realm_access", map[string]any{"roles": []string{"admin", "user"}}). + Claim("email_verified", true). + Build() + require.NoError(t, err) + + rawToken, err := jwt.Sign(token, jwt.WithInsecureNoSignature()) + require.NoError(t, err) + + return token, string(rawToken) +} + +func decodeAuditPayload(t *testing.T, payload json.RawMessage) map[string]any { + t.Helper() + + var decoded map[string]any + require.NoError(t, json.Unmarshal(payload, &decoded)) + return decoded +} + func TestAuditRewrapSuccess(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.RewrapSuccess(ctx, rewrapParams) }) @@ -175,7 +204,7 @@ func TestAuditRewrapSuccess(t *testing.T) { } func TestAuditRewrapFailure(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.RewrapFailure(ctx, rewrapParams) }) @@ -234,7 +263,7 @@ func TestAuditRewrapFailure(t *testing.T) { } func TestPolicyCRUDSuccess(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.PolicyCRUDSuccess(ctx, policyCRUDParams) }) @@ -284,7 +313,7 @@ func TestPolicyCRUDSuccess(t *testing.T) { } func TestPolicyCrudFailure(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.PolicyCRUDFailure(ctx, policyCRUDParams) }) @@ -333,8 +362,187 @@ func TestPolicyCrudFailure(t *testing.T) { assert.JSONEq(t, expectedAuditLog, loggedMessage) } +func TestAuditJWTClaimMappingsApplyToPolicyAudit(t *testing.T) { + token, rawToken := createTestJWTForAudit(t) + + logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { + return ctxAuth.ContextWithAuthNInfo(ctx, nil, token, rawToken) + }, func(ctx context.Context, l *Logger) { + require.NoError(t, l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + {Claim: "realm_access.roles", Path: "eventMetaData.requester.roles"}, + {Claim: "email_verified", Path: "eventMetaData.requester.emailVerified"}, + }, + })) + l.PolicyCRUDSuccess(ctx, policyCRUDParams) + }) + + payload := decodeAuditPayload(t, logEntry.Audit) + eventMetaData, ok := payload["eventMetaData"].(map[string]any) + require.True(t, ok) + requester, ok := eventMetaData["requester"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", requester["sub"]) + assert.Equal(t, []any{"admin", "user"}, requester["roles"]) + assert.Equal(t, true, requester["emailVerified"]) +} + +func TestAuditJWTClaimMappingsCanWriteToEntityMetadata(t *testing.T) { + token, rawToken := createTestJWTForAudit(t) + + logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { + return ctxAuth.ContextWithAuthNInfo(ctx, nil, token, rawToken) + }, func(ctx context.Context, l *Logger) { + require.NoError(t, l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.entityMetadata.sub"}, + {Claim: "realm_access.roles", Path: "eventMetaData.entityMetadata.roles"}, + {Claim: "email_verified", Path: "eventMetaData.entityMetadata.emailVerified"}, + }, + })) + l.PolicyCRUDSuccess(ctx, policyCRUDParams) + }) + + payload := decodeAuditPayload(t, logEntry.Audit) + eventMetaData, ok := payload["eventMetaData"].(map[string]any) + require.True(t, ok) + entityMetadata, ok := eventMetaData["entityMetadata"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", entityMetadata["sub"]) + assert.Equal(t, []any{"admin", "user"}, entityMetadata["roles"]) + assert.Equal(t, true, entityMetadata["emailVerified"]) +} + +func TestAuditJWTClaimMappingsCoverNamedAndUnnamedPaths(t *testing.T) { + token, rawToken := createTestJWTForAudit(t) + + logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { + return ctxAuth.ContextWithAuthNInfo(ctx, nil, token, rawToken) + }, func(ctx context.Context, l *Logger) { + require.NoError(t, l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "object.name"}, + {Claim: "realm_access.roles", Path: "actor.attributes"}, + {Claim: "sub", Path: "original.request.jwt.sub"}, + {Claim: "sub", Path: "banana"}, + {Claim: "email_verified", Path: "kiwi.requester.emailVerified"}, + }, + })) + l.PolicyCRUDSuccess(ctx, policyCRUDParams) + }) + + payload := decodeAuditPayload(t, logEntry.Audit) + + object, ok := payload["object"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", object["name"]) + + actor, ok := payload["actor"].(map[string]any) + require.True(t, ok) + assert.Equal(t, []any{"admin", "user"}, actor["attributes"]) + + original, ok := payload["original"].(map[string]any) + require.True(t, ok) + request, ok := original["request"].(map[string]any) + require.True(t, ok) + jwt, ok := request["jwt"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", jwt["sub"]) + + assert.Equal(t, "jwt-user", payload["banana"]) + + kiwi, ok := payload["kiwi"].(map[string]any) + require.True(t, ok) + requester, ok := kiwi["requester"].(map[string]any) + require.True(t, ok) + assert.Equal(t, true, requester["emailVerified"]) +} + +func TestAuditJWTClaimMappingsLeaveReservedFieldsUntouched(t *testing.T) { + token, rawToken := createTestJWTForAudit(t) + + logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { + return ctxAuth.ContextWithAuthNInfo(ctx, nil, token, rawToken) + }, func(ctx context.Context, l *Logger) { + require.NoError(t, l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + }, + })) + l.PolicyCRUDSuccess(ctx, policyCRUDParams) + }) + + payload := decodeAuditPayload(t, logEntry.Audit) + assert.Equal(t, TestRequestID.String(), payload["requestID"]) + + eventMetaData, ok := payload["eventMetaData"].(map[string]any) + require.True(t, ok) + requester, ok := eventMetaData["requester"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", requester["sub"]) +} + +func TestAuditApplyConfigRejectsReservedPaths(t *testing.T) { + t.Run("requestID", func(t *testing.T) { + assertReservedAuditPathRejected(t, "requestID") + }) + + t.Run("clientInfo.userAgent", func(t *testing.T) { + assertReservedAuditPathRejected(t, "clientInfo.userAgent") + }) + + t.Run("clientInfo.requestIP", func(t *testing.T) { + assertReservedAuditPathRejected(t, "clientInfo.requestIP") + }) +} + +func TestAuditApplyConfigClonesMappings(t *testing.T) { + l, _ := createTestLogger() + cfg := Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + }, + } + + require.NoError(t, l.ApplyConfig(cfg)) + + cfg.JWTClaimMappings[0].Path = "eventMetaData.requester.changed" + + require.Equal(t, "eventMetaData.requester.sub", l.config.JWTClaimMappings[0].Path) +} + +func TestAuditLoggerWithClonesMappings(t *testing.T) { + l, _ := createTestLogger() + require.NoError(t, l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + }, + })) + + child := l.With("namespace", "policy") + l.config.JWTClaimMappings[0].Path = "eventMetaData.requester.changed" + + require.Equal(t, "eventMetaData.requester.sub", child.config.JWTClaimMappings[0].Path) +} + +func assertReservedAuditPathRejected(t *testing.T, path string) { + t.Helper() + + l, _ := createTestLogger() + err := l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: path}, + }, + }) + + require.Error(t, err) + require.ErrorIs(t, err, ErrReservedAuditPath) + require.ErrorContains(t, err, "jwt_claim_mappings[0].path") +} + func TestDeferredRewrapSuccess(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.RewrapSuccess(ctx, rewrapParams) }) @@ -393,7 +601,7 @@ func TestDeferredRewrapSuccess(t *testing.T) { } func TestDeferredRewrapCancelled(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.RewrapSuccess(ctx, rewrapParams) panic(errors.New("operation failed")) }) @@ -455,7 +663,7 @@ func TestDeferredRewrapCancelled(t *testing.T) { } func TestDeferredPolicyCRUDSuccess(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.PolicyCRUDSuccess(ctx, policyCRUDParams) }) @@ -518,7 +726,7 @@ func TestGetDecision(t *testing.T) { FQNs: []string{"test-fqn"}, } - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.GetDecision(ctx, params) }) expectedAuditLog := fmt.Sprintf( diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go new file mode 100644 index 0000000000..7dee297e7d --- /dev/null +++ b/service/logger/audit/schema.go @@ -0,0 +1,201 @@ +package audit + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +var ( + // ErrReservedAuditPath indicates a claim destination targets a protected audit field. + ErrReservedAuditPath = errors.New("reserved audit path") + // ErrUnknownAuditPath indicates a claim destination traverses an unknown closed-schema path. + ErrUnknownAuditPath = errors.New("unknown audit path") + // ErrAuditContainerPath indicates a claim destination resolves to a container instead of a writable leaf. + ErrAuditContainerPath = errors.New("audit path resolves to a container") + // ErrOverlappingAuditPaths indicates two claim destinations conflict because one is a prefix of the other. + ErrOverlappingAuditPaths = errors.New("overlapping audit paths") + + auditClaimDestinationSchema = mustBuildAuditPathSchema(reflect.TypeOf(EventObject{})) +) + +type auditFieldOptions struct { + name string + reserved bool + extensible bool +} + +type auditPathSchema struct { + children map[string]*auditPathSchema + isLeaf bool + reserved bool + extensible bool +} + +func mustBuildAuditPathSchema(rootType reflect.Type) *auditPathSchema { + root, err := buildAuditPathSchema(rootType) + if err != nil { + panic(err) + } + return root +} + +func buildAuditPathSchema(rootType reflect.Type) (*auditPathSchema, error) { + rootType = indirectType(rootType) + if rootType.Kind() != reflect.Struct { + return nil, fmt.Errorf("audit schema root must be a struct, got %s", rootType.Kind()) + } + + root := &auditPathSchema{ + children: make(map[string]*auditPathSchema), + extensible: true, + } + if err := addAuditSchemaFields(root, rootType); err != nil { + return nil, err + } + return root, nil +} + +func addAuditSchemaFields(parent *auditPathSchema, structType reflect.Type) error { + for i := range structType.NumField() { + field := structType.Field(i) + if !field.IsExported() { + continue + } + + opts, err := parseAuditFieldOptions(field) + if err != nil { + return err + } + if opts.name == "" { + continue + } + if _, exists := parent.children[opts.name]; exists { + return fmt.Errorf("duplicate audit schema path %q on %s", opts.name, structType) + } + + child := &auditPathSchema{ + children: make(map[string]*auditPathSchema), + isLeaf: isWritableAuditLeaf(indirectType(field.Type)), + reserved: opts.reserved, + extensible: opts.extensible, + } + parent.children[opts.name] = child + + fieldType := indirectType(field.Type) + if fieldType.Kind() == reflect.Struct { + if err := addAuditSchemaFields(child, fieldType); err != nil { + return err + } + } + if len(child.children) == 0 { + child.children = nil + } + } + + return nil +} + +func parseAuditFieldOptions(field reflect.StructField) (auditFieldOptions, error) { + tag := field.Tag.Get("audit") + if tag == "-" { + return auditFieldOptions{}, nil + } + + reserved, extensible, err := parseAuditTag(tag) + if err != nil { + return auditFieldOptions{}, fmt.Errorf("field %s: %w", field.Name, err) + } + name := parseJSONFieldName(field) + if name == "" { + return auditFieldOptions{}, nil + } + + return auditFieldOptions{ + name: name, + reserved: reserved, + extensible: extensible, + }, nil +} + +func parseAuditTag(tag string) (bool, bool, error) { + switch tag { + case "": + return false, false, nil + case "reserved": + return true, false, nil + case "extensible": + return false, true, nil + default: + return false, false, fmt.Errorf("unknown audit tag %q", tag) + } +} + +func parseJSONFieldName(field reflect.StructField) string { + tag := field.Tag.Get("json") + if tag == "-" { + return "" + } + if tag == "" { + return field.Name + } + + name, _, _ := strings.Cut(tag, ",") + if name == "" { + return field.Name + } + return name +} + +func indirectType(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + return t +} + +func isWritableAuditLeaf(t reflect.Type) bool { + kind := indirectType(t).Kind() + return kind != reflect.Struct && kind != reflect.Map +} + +func validateClaimDestinationPath(path string) error { + if path == "" { + return fmt.Errorf("%w: empty path", ErrUnknownAuditPath) + } + + segments := strings.Split(path, ".") + for _, segment := range segments { + if segment == "" { + return fmt.Errorf("%w: %s", ErrUnknownAuditPath, path) + } + } + + current := auditClaimDestinationSchema + for idx, segment := range segments { + child, ok := current.children[segment] + if !ok { + if current.extensible { + return nil + } + return fmt.Errorf("%w: %s", ErrUnknownAuditPath, path) + } + + isLast := idx == len(segments)-1 + if isLast { + switch { + case child.reserved: + return fmt.Errorf("%w: %s", ErrReservedAuditPath, path) + case child.extensible || !child.isLeaf: + return fmt.Errorf("%w: %s", ErrAuditContainerPath, path) + default: + return nil + } + } + + current = child + } + + return nil +} diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go new file mode 100644 index 0000000000..06bded902a --- /dev/null +++ b/service/logger/audit/schema_test.go @@ -0,0 +1,129 @@ +package audit + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateClaimDestinationPath(t *testing.T) { + t.Run("allows writable leaf paths", func(t *testing.T) { + require.NoError(t, validateClaimDestinationPath("object.id")) + require.NoError(t, validateClaimDestinationPath("actor.attributes")) + }) + + t.Run("allows nested paths below extensible maps", func(t *testing.T) { + require.NoError(t, validateClaimDestinationPath("eventMetaData.requester.sub")) + require.NoError(t, validateClaimDestinationPath("original.request.headers.user")) + }) + + t.Run("allows top level additions", func(t *testing.T) { + require.NoError(t, validateClaimDestinationPath("banana")) + require.NoError(t, validateClaimDestinationPath("banana.requester.sub")) + }) + + t.Run("rejects reserved paths", func(t *testing.T) { + err := validateClaimDestinationPath("requestID") + require.ErrorIs(t, err, ErrReservedAuditPath) + + err = validateClaimDestinationPath("action.result") + require.ErrorIs(t, err, ErrReservedAuditPath) + + err = validateClaimDestinationPath("clientInfo.userAgent") + require.ErrorIs(t, err, ErrReservedAuditPath) + + err = validateClaimDestinationPath("clientInfo.requestIP") + require.ErrorIs(t, err, ErrReservedAuditPath) + }) + + t.Run("rejects container paths", func(t *testing.T) { + err := validateClaimDestinationPath("eventMetaData") + require.ErrorIs(t, err, ErrAuditContainerPath) + + err = validateClaimDestinationPath("object") + require.ErrorIs(t, err, ErrAuditContainerPath) + + err = validateClaimDestinationPath("object.attributes") + require.ErrorIs(t, err, ErrAuditContainerPath) + }) + + t.Run("rejects unknown nested paths below closed containers", func(t *testing.T) { + err := validateClaimDestinationPath("object.extra.foo") + require.ErrorIs(t, err, ErrUnknownAuditPath) + }) + + t.Run("rejects malformed dot paths", func(t *testing.T) { + for _, path := range []string{ + ".banana", // leading dot + "banana.", // trailing dot + "banana..mango", // consecutive dots + ".", // single dot + "..", // double dot + "a.b..c.d", // mid-path empty segment + ".a.b", // leading dot with valid tail + "eventMetaData.", // trailing dot after known node + } { + t.Run(path, func(t *testing.T) { + err := validateClaimDestinationPath(path) + require.ErrorIs(t, err, ErrUnknownAuditPath) + }) + } + }) + + t.Run("rejects empty path", func(t *testing.T) { + err := validateClaimDestinationPath("") + require.ErrorIs(t, err, ErrUnknownAuditPath) + }) +} + +func TestValidateNoOverlappingPaths(t *testing.T) { + t.Run("allows sibling paths", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "sub", Path: "banana.kiwi"}, + {Claim: "email", Path: "banana.mango"}, + }) + require.NoError(t, err) + }) + + t.Run("rejects short prefix of long", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "sub", Path: "banana"}, + {Claim: "email", Path: "banana.kiwi.mango"}, + }) + require.ErrorIs(t, err, ErrOverlappingAuditPaths) + }) + + t.Run("rejects long prefix of short", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "email", Path: "banana.kiwi.mango"}, + {Claim: "sub", Path: "banana"}, + }) + require.ErrorIs(t, err, ErrOverlappingAuditPaths) + }) + + t.Run("allows identical depth different leaves", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + {Claim: "email", Path: "eventMetaData.requester.email"}, + }) + require.NoError(t, err) + }) + + t.Run("rejects duplicate destination paths", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + {Claim: "email", Path: "eventMetaData.requester.sub"}, + }) + require.ErrorIs(t, err, ErrOverlappingAuditPaths) + }) +} + +func TestBuildAuditPathSchemaRejectsUnknownTags(t *testing.T) { + type badStruct struct { + Field string `json:"field" audit:"resreved"` + } + _, err := buildAuditPathSchema(reflect.TypeOf(badStruct{})) + require.Error(t, err) + require.ErrorContains(t, err, "unknown audit tag") +} diff --git a/service/logger/audit/utils.go b/service/logger/audit/utils.go index 2d1bf12e88..743b598072 100644 --- a/service/logger/audit/utils.go +++ b/service/logger/audit/utils.go @@ -19,31 +19,30 @@ type EventObject struct { Object auditEventObject `json:"object"` Action eventAction `json:"action"` Actor auditEventActor `json:"actor"` - EventMetaData auditEventMetadata `json:"eventMetaData"` + EventMetaData auditEventMetadata `json:"eventMetaData" audit:"extensible"` ClientInfo eventClientInfo `json:"clientInfo"` - Original map[string]any `json:"original,omitempty"` - Updated map[string]any `json:"updated,omitempty"` - RequestID uuid.UUID `json:"requestId"` - Timestamp string `json:"timestamp"` + Original map[string]any `json:"original,omitempty" audit:"extensible"` + Updated map[string]any `json:"updated,omitempty" audit:"extensible"` + RequestID uuid.UUID `json:"requestID" audit:"reserved"` + Timestamp string `json:"timestamp" audit:"reserved"` } func (e EventObject) LogValue() slog.Value { - return slog.GroupValue( - slog.Any("object", e.Object), - slog.Any("action", e.Action), - slog.Any("actor", e.Actor), - slog.Any("eventMetaData", e.EventMetaData), - slog.Any("clientInfo", e.ClientInfo), - slog.Any("original", e.Original), - slog.Any("updated", e.Updated), - slog.String("requestID", e.RequestID.String()), - slog.String("timestamp", e.Timestamp)) + return slog.AnyValue(e.emittedPayloadMap()) +} + +func (e EventObject) emittedPayloadMap() map[string]any { + entry, ok := normalizeAuditValue(e).(map[string]any) + if !ok { + panic("normalized audit payload must be a map") + } + return entry } // event.object type auditEventObject struct { - Type ObjectType `json:"type"` + Type ObjectType `json:"type" audit:"reserved"` ID string `json:"id"` Name string `json:"name,omitempty"` Attributes eventObjectAttributes `json:"attributes,omitempty"` @@ -73,8 +72,8 @@ func (e eventObjectAttributes) LogValue() slog.Value { // event.action type eventAction struct { - Type ActionType `json:"type"` - Result ActionResult `json:"result"` + Type ActionType `json:"type" audit:"reserved"` + Result ActionResult `json:"result" audit:"reserved"` } func (e eventAction) LogValue() slog.Value { @@ -85,7 +84,7 @@ func (e eventAction) LogValue() slog.Value { // event.actor type auditEventActor struct { - ID string `json:"id"` + ID string `json:"id" audit:"reserved"` Attributes []any `json:"attributes"` } @@ -97,9 +96,9 @@ func (e auditEventActor) LogValue() slog.Value { // event.clientInfo type eventClientInfo struct { - UserAgent string `json:"userAgent"` - Platform string `json:"platform"` - RequestIP string `json:"requestIp"` + UserAgent string `json:"userAgent" audit:"reserved"` + Platform string `json:"platform" audit:"reserved"` + RequestIP string `json:"requestIP" audit:"reserved"` } func (e eventClientInfo) LogValue() slog.Value { diff --git a/service/logger/logger.go b/service/logger/logger.go index eec308ac5b..a96ec7a7a9 100644 --- a/service/logger/logger.go +++ b/service/logger/logger.go @@ -11,7 +11,7 @@ // Original map[string]interface{} `json:"original,omitempty"` // Updated map[string]interface{} `json:"updated,omitempty"` -// RequestID uuid.UUID `json:"requestId"` +// RequestID uuid.UUID `json:"requestID"` // Timestamp string `json:"timestamp"` // } diff --git a/service/pkg/auth/context_auth.go b/service/pkg/auth/context_auth.go index 91005e42b6..e6a55279f6 100644 --- a/service/pkg/auth/context_auth.go +++ b/service/pkg/auth/context_auth.go @@ -6,7 +6,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/opentdf/platform/service/logger" "google.golang.org/grpc/metadata" ) @@ -30,6 +29,12 @@ type authContext struct { rawToken string } +// optionalErrorLogger keeps pkg/auth decoupled from the concrete logger package +// and helps avoid an import cycle. +type optionalErrorLogger interface { + ErrorContext(context.Context, string, ...any) +} + func ContextWithAuthNInfo(ctx context.Context, key jwk.Key, accessToken jwt.Token, raw string) context.Context { return context.WithValue(ctx, authnContextKey, &authContext{ key, @@ -38,7 +43,7 @@ func ContextWithAuthNInfo(ctx context.Context, key jwk.Key, accessToken jwt.Toke }) } -func getContextDetails(ctx context.Context, l *logger.Logger) *authContext { +func getContextDetails(ctx context.Context, l optionalErrorLogger) *authContext { key := ctx.Value(authnContextKey) if key == nil { return nil @@ -47,37 +52,43 @@ func getContextDetails(ctx context.Context, l *logger.Logger) *authContext { return c } - // We should probably return an error here? - l.ErrorContext(ctx, "invalid authContext") + if l != nil { + l.ErrorContext(ctx, "invalid authContext") + } return nil } -func GetJWKFromContext(ctx context.Context, l *logger.Logger) jwk.Key { +func GetJWKFromContext(ctx context.Context, l optionalErrorLogger) jwk.Key { if c := getContextDetails(ctx, l); c != nil { return c.key } return nil } -func GetAccessTokenFromContext(ctx context.Context, l *logger.Logger) jwt.Token { +func GetAccessTokenFromContext(ctx context.Context, l optionalErrorLogger) jwt.Token { if c := getContextDetails(ctx, l); c != nil { - return c.accessToken + if c.accessToken != nil { + return c.accessToken + } } return nil } -func GetRawAccessTokenFromContext(ctx context.Context, l *logger.Logger) string { +func GetRawAccessTokenFromContext(ctx context.Context, l optionalErrorLogger) string { if c := getContextDetails(ctx, l); c != nil { - return c.rawToken + if c.rawToken != "" { + return c.rawToken + } } return "" } // EnrichIncomingContextMetadataWithAuthn adds the access token and client ID to incoming context metadata // -// Adding the authn info to gRPC metadata propagates it across services rather than strictly -// in-process within Go alone -func EnrichIncomingContextMetadataWithAuthn(ctx context.Context, l *logger.Logger, clientID string) context.Context { +// This transports an already-authenticated token across internal hops. It does not +// validate the token again; the ConnectUnaryServerInterceptor auth middleware does that +// before ContextWithAuthNInfo is created. +func EnrichIncomingContextMetadataWithAuthn(ctx context.Context, l optionalErrorLogger, clientID string) context.Context { rawToken := GetRawAccessTokenFromContext(ctx, l) md, ok := metadata.FromIncomingContext(ctx) diff --git a/service/pkg/auth/context_auth_test.go b/service/pkg/auth/context_auth_test.go index 63354eb7f4..e427c5eb72 100644 --- a/service/pkg/auth/context_auth_test.go +++ b/service/pkg/auth/context_auth_test.go @@ -6,7 +6,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/opentdf/platform/service/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" @@ -38,7 +37,7 @@ func TestGetJWKFromContext(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), mockJWK, nil, "") // Retrieve the JWK and assert - retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger()) + retrievedJWK := GetJWKFromContext(ctx, nil) assert.NotNil(t, retrievedJWK, "JWK should not be nil") assert.Equal(t, mockJWK, retrievedJWK, "Retrieved JWK should match the mock JWK") } @@ -49,7 +48,7 @@ func TestGetAccessTokenFromContext(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), nil, mockJWT, "") // Retrieve the JWT and assert - retrievedJWT := GetAccessTokenFromContext(ctx, logger.CreateTestLogger()) + retrievedJWT := GetAccessTokenFromContext(ctx, nil) assert.NotNil(t, retrievedJWT, "Access token should not be nil") assert.Equal(t, mockJWT, retrievedJWT, "Retrieved JWT should match the mock JWT") } @@ -60,26 +59,53 @@ func TestGetRawAccessTokenFromContext(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), nil, nil, rawToken) // Retrieve the raw token and assert - retrievedRawToken := GetRawAccessTokenFromContext(ctx, logger.CreateTestLogger()) + retrievedRawToken := GetRawAccessTokenFromContext(ctx, nil) assert.Equal(t, rawToken, retrievedRawToken, "Retrieved raw token should match the mock raw token") } +func TestGetRawAccessTokenFromContextDoesNotFallbackToMetadata(t *testing.T) { + t.Run("incoming access token metadata", func(t *testing.T) { + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, "incoming-token")) + retrievedRawToken := GetRawAccessTokenFromContext(ctx, nil) + assert.Empty(t, retrievedRawToken) + }) + + t.Run("outgoing authorization metadata", func(t *testing.T) { + ctx := metadata.NewOutgoingContext(t.Context(), metadata.Pairs("Authorization", "Bearer outgoing-token")) + retrievedRawToken := GetRawAccessTokenFromContext(ctx, nil) + assert.Empty(t, retrievedRawToken) + }) +} + +func TestGetAccessTokenFromContextDoesNotFallbackToMetadata(t *testing.T) { + mockJWT, err := jwt.NewBuilder(). + Subject("metadata-user"). + Claim("roles", []string{"admin"}). + Build() + require.NoError(t, err) + + rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) + + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, string(rawToken))) + retrievedJWT := GetAccessTokenFromContext(ctx, nil) + assert.Nil(t, retrievedJWT) +} + func TestGetContextDetailsInvalidType(t *testing.T) { // Create a context with an invalid type ctx := context.WithValue(t.Context(), authnContextKey, "invalidType") // Assert that GetJWKFromContext handles the invalid type correctly - retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger()) + retrievedJWK := GetJWKFromContext(ctx, nil) assert.Nil(t, retrievedJWK, "JWK should be nil when context value is invalid") } func TestEnrichIncomingContextMetadataWithAuthn(t *testing.T) { mockClientID := "test-client-id" - l := logger.CreateTestLogger() - t.Run("should add access token and client id to metadata", func(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), nil, nil, "raw-token-string") - enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, l, mockClientID) + enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, nil, mockClientID) md, ok := metadata.FromIncomingContext(enrichedCtx) require.True(t, ok) @@ -95,7 +121,7 @@ func TestEnrichIncomingContextMetadataWithAuthn(t *testing.T) { t.Run("should not set client id if empty", func(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), nil, nil, "raw-token-string") - enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, l, "") + enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, nil, "") md, ok := metadata.FromIncomingContext(enrichedCtx) require.True(t, ok) @@ -108,7 +134,7 @@ func TestEnrichIncomingContextMetadataWithAuthn(t *testing.T) { originalMD := metadata.New(map[string]string{"original-key": "original-value"}) ctx := metadata.NewIncomingContext(t.Context(), originalMD) ctx = ContextWithAuthNInfo(ctx, nil, nil, "raw-token-string") - enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, l, mockClientID) + enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, nil, mockClientID) md, ok := metadata.FromIncomingContext(enrichedCtx) require.True(t, ok) diff --git a/service/pkg/config/config.go b/service/pkg/config/config.go index aac79e7fbc..355cca081d 100644 --- a/service/pkg/config/config.go +++ b/service/pkg/config/config.go @@ -10,6 +10,7 @@ import ( "github.com/go-playground/validator/v10" "github.com/opentdf/platform/service/internal/server" "github.com/opentdf/platform/service/logger" + auditcfg "github.com/opentdf/platform/service/logger/audit" "github.com/opentdf/platform/service/pkg/db" "github.com/spf13/viper" ) @@ -76,6 +77,9 @@ type Config struct { // Logger represents the configuration settings for the logger. Logger logger.Config `mapstructure:"logger" json:"logger"` + // Audit represents the configuration settings for audit enrichment. + Audit auditcfg.Config `mapstructure:"audit" json:"audit"` + // Mode specifies which services to run. // By default, it runs all services. Mode []string `mapstructure:"mode" json:"mode" default:"[\"all\"]"` @@ -136,6 +140,7 @@ var ( func (c *Config) LogValue() slog.Value { return slog.GroupValue( slog.Bool("dev_mode", c.DevMode), + slog.Any("audit", c.Audit), slog.Any("db", c.DB), slog.Any("logger", c.Logger), slog.Any("mode", c.Mode), @@ -294,6 +299,9 @@ func (c *Config) Reload(ctx context.Context) error { if err := validator.New().Struct(c); err != nil { return errors.Join(err, ErrUnmarshallingConfig) } + if err := c.Audit.Validate(); err != nil { + return errors.Join(err, ErrUnmarshallingConfig) + } if skew := c.Security.ClockSkew(); skew > DefaultUnsafeClockSkew { slog.WarnContext(ctx, diff --git a/service/pkg/config/config_test.go b/service/pkg/config/config_test.go index 4ce51adc42..8b1a91917f 100644 --- a/service/pkg/config/config_test.go +++ b/service/pkg/config/config_test.go @@ -6,6 +6,7 @@ import ( "os" "testing" + auditcfg "github.com/opentdf/platform/service/logger/audit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -356,6 +357,8 @@ func TestLoad_Precedence(t *testing.T) { setupLoaders func(t *testing.T, configFile string) []Loader envVars map[string]string err error + causeErr error + errContains string fileContent string asserts func(t *testing.T, cfg *Config) }{ @@ -388,15 +391,43 @@ server: port: 9090 logger: level: warn +audit: + jwt_claim_mappings: + - claim: email_verified + path: eventMetaData.requester.emailVerified `, asserts: func(t *testing.T, cfg *Config) { // Values from file assert.Equal(t, 9090, cfg.Server.Port) assert.Equal(t, "warn", cfg.Logger.Level) + assert.Equal(t, []auditcfg.JWTClaimMapping{ + {Claim: "email_verified", Path: "eventMetaData.requester.emailVerified"}, + }, cfg.Audit.JWTClaimMappings) // Value from defaults assert.Equal(t, []string{"all"}, cfg.Mode) }, }, + { + name: "top level audit mapping paths pass validation", + setupLoaders: func(t *testing.T, configFile string) []Loader { + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + return []Loader{file, defaults} + }, + fileContent: ` +audit: + jwt_claim_mappings: + - claim: email_verified + path: banana.requester.emailVerified +`, + asserts: func(t *testing.T, cfg *Config) { + assert.Equal(t, []auditcfg.JWTClaimMapping{ + {Claim: "email_verified", Path: "banana.requester.emailVerified"}, + }, cfg.Audit.JWTClaimMappings) + }, + }, { name: "file with extras and defaults", setupLoaders: func(t *testing.T, configFile string) []Loader { @@ -426,6 +457,24 @@ special_key: assert.Equal(t, []string{"all"}, cfg.Mode) }, }, + { + name: "invalid audit mapping path fails validation", + setupLoaders: func(t *testing.T, configFile string) []Loader { + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + return []Loader{file, defaults} + }, + err: ErrUnmarshallingConfig, + causeErr: auditcfg.ErrReservedAuditPath, + fileContent: ` +audit: + jwt_claim_mappings: + - claim: sub + path: requestID +`, + }, { name: "env overrides file and defaults except client_id", setupLoaders: func(t *testing.T, configFile string) []Loader { @@ -662,6 +711,14 @@ logger: // Assertions if tc.err != nil { require.Error(t, err) + require.ErrorIs(t, err, tc.err) + if tc.causeErr != nil { + require.ErrorIs(t, err, tc.causeErr) + } + if tc.errContains != "" { + require.ErrorContains(t, err, tc.errContains) + } + require.Nil(t, cfg) } else { require.NoError(t, err) require.NotNil(t, cfg) diff --git a/service/pkg/server/options.go b/service/pkg/server/options.go index db3952ceef..89c25725e7 100644 --- a/service/pkg/server/options.go +++ b/service/pkg/server/options.go @@ -83,8 +83,8 @@ func WithPublicRoutes(routes []string) StartOptions { } } -// WithIPCReauthRoutes option sets the IPC reauthorization routes for the server. -// It enables the server to reauthorize IPC routes and embed the token on the context. +// WithIPCReauthRoutes sets the IPC routes that should be fully reauthorized +// instead of only rehydrating auth context from propagated metadata. func WithIPCReauthRoutes(routes []string) StartOptions { return func(c StartConfig) StartConfig { c.IPCReauthRoutes = routes diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index 5af961893a..20e022ba5f 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -164,15 +164,9 @@ func startServices(ctx context.Context, params startServicesParams) (func(), err // If ns has log_level in config, create new logger with that level if err == nil { - if extractedLogLevel != cfg.Logger.Level { - slog.Debug("configuring logger") - newLoggerConfig := cfg.Logger - newLoggerConfig.Level = extractedLogLevel - newSvcLogger, err := logging.NewLogger(newLoggerConfig) - // only assign if logger successfully created - if err == nil { - svcLogger = newSvcLogger.With("namespace", ns) - } + svcLogger, err = buildNamespaceLogger(svcLogger, cfg, ns, extractedLogLevel) + if err != nil { + return func() {}, err } } @@ -287,6 +281,26 @@ func extractServiceLoggerConfig(cfg config.ServiceConfig) (string, error) { return "", fmt.Errorf("could not decode service log level: %w", err) } +func buildNamespaceLogger(baseLogger *logging.Logger, cfg *config.Config, ns, level string) (*logging.Logger, error) { + if level == cfg.Logger.Level { + return baseLogger, nil + } + + slog.Debug("configuring logger") + newLoggerConfig := cfg.Logger + newLoggerConfig.Level = level + + namespaceLogger, loggerErr := logging.NewLogger(newLoggerConfig) + if loggerErr != nil { + return nil, fmt.Errorf("invalid namespace logger config for %s: %w", ns, loggerErr) + } + + if err := namespaceLogger.Audit.ApplyConfig(cfg.Audit); err != nil { + return nil, fmt.Errorf("could not apply audit config for namespace %s: %w", ns, err) + } + return namespaceLogger.With("namespace", ns), nil +} + // newServiceDBClient creates a new database client for the specified namespace. // It initializes the client with the provided context, logger configuration, database configuration, // namespace, and migrations. It returns the created client and any error encountered during creation. diff --git a/service/pkg/server/services_test.go b/service/pkg/server/services_test.go index 0fc7130161..b33e72de20 100644 --- a/service/pkg/server/services_test.go +++ b/service/pkg/server/services_test.go @@ -220,6 +220,19 @@ func (suite *ServiceTestSuite) Test_RegisterServices_In_Mode_Core_Plus_Kas_Expec suite.Equal(serviceregistry.ModeERS.String(), ers.Mode) } +func (suite *ServiceTestSuite) TestBuildNamespaceLoggerRejectsInvalidOverrideLevel() { + baseLogger := logger.CreateTestLogger() + + cfg := &config.Config{ + Logger: logger.Config{Output: "stdout", Level: "info", Type: "json"}, + } + + namespaceLogger, err := buildNamespaceLogger(baseLogger, cfg, "policy", "not-a-level") + suite.Require().Error(err) + suite.Nil(namespaceLogger) + suite.ErrorContains(err, "invalid namespace logger config for policy") +} + func (suite *ServiceTestSuite) TestStartServicesWithVariousCases() { ctx := context.Background() diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index f3513a090d..ea023d5e31 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -112,6 +112,9 @@ func Start(f ...StartOptions) error { if err != nil { return fmt.Errorf("could not start logger: %w", err) } + if err := logger.Audit.ApplyConfig(cfg.Audit); err != nil { + return fmt.Errorf("could not apply audit config: %w", err) + } // Set default for places we can't pass the logger slog.SetDefault(logger.Logger)