From a726ca38c8a1c178911870bcf1df9da05a5c3115 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 12:48:54 -0700 Subject: [PATCH 01/20] feat(audit): add global config and context enrichment Signed-off-by: jakedoublev --- opentdf-core-mode.yaml | 15 +- opentdf-dev.yaml | 11 + opentdf-ers-mode.yaml | 12 ++ opentdf-ers-test.yaml | 13 +- opentdf-example.yaml | 11 + opentdf-kas-mode.yaml | 11 + service/internal/dotnotation/dotnotation.go | 78 +++++++ .../internal/dotnotation/dotnotation_test.go | 73 +++++++ service/internal/server/server.go | 2 +- service/logger/audit/config.go | 19 ++ .../logger/audit/contextServerInterceptor.go | 3 +- service/logger/audit/enrichment.go | 195 ++++++++++++++++++ service/logger/audit/logger.go | 11 +- service/logger/audit/logger_test.go | 109 +++++++++- service/logger/audit/utils.go | 33 +++ service/pkg/auth/context_auth.go | 92 ++++++++- service/pkg/auth/context_auth_test.go | 47 ++++- service/pkg/config/config.go | 5 + service/pkg/config/config_test.go | 12 ++ service/pkg/server/services.go | 1 + service/pkg/server/start.go | 1 + 21 files changed, 722 insertions(+), 32 deletions(-) create mode 100644 service/internal/dotnotation/dotnotation.go create mode 100644 service/internal/dotnotation/dotnotation_test.go create mode 100644 service/logger/audit/config.go create mode 100644 service/logger/audit/enrichment.go diff --git a/opentdf-core-mode.yaml b/opentdf-core-mode.yaml index 98401d19ee..d6a4113313 100644 --- a/opentdf-core-mode.yaml +++ b/opentdf-core-mode.yaml @@ -11,6 +11,19 @@ logger: level: debug type: text output: stderr +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles + # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. + # audited_entity_jwt_claims: + # - sub +services: + kas: # DB and Server configurations are defaulted for local development # db: # host: localhost @@ -57,4 +70,4 @@ server: maxage: 3600 grpc: reflectionEnabled: true # Default is false - port: 8383 \ No newline at end of file + port: 8383 diff --git a/opentdf-dev.yaml b/opentdf-dev.yaml index 93a471101f..5304f54f08 100644 --- a/opentdf-dev.yaml +++ b/opentdf-dev.yaml @@ -2,6 +2,17 @@ logger: level: debug type: text output: stderr +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles + # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. + # audited_entity_jwt_claims: + # - sub # DB and Server configurations are defaulted for local development # db: # host: localhost diff --git a/opentdf-ers-mode.yaml b/opentdf-ers-mode.yaml index 220317039c..662d10a298 100644 --- a/opentdf-ers-mode.yaml +++ b/opentdf-ers-mode.yaml @@ -5,7 +5,19 @@ logger: level: debug type: text output: stderr +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles + # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. + # audited_entity_jwt_claims: + # - sub services: + kas: entityresolution: log_level: info url: http://localhost:8888/auth diff --git a/opentdf-ers-test.yaml b/opentdf-ers-test.yaml index 38d4833592..fdbfaa3d16 100644 --- a/opentdf-ers-test.yaml +++ b/opentdf-ers-test.yaml @@ -12,6 +12,17 @@ mode: standalone logger: level: info type: json +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles + # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. + # audited_entity_jwt_claims: + # - sub crypto: type: standard @@ -256,4 +267,4 @@ development: # Allow insecure connections for testing allow_insecure: true # Disable some validations for easier testing - strict_validation: false \ No newline at end of file + strict_validation: false diff --git a/opentdf-example.yaml b/opentdf-example.yaml index a0b97da826..a3b2b3f6fb 100644 --- a/opentdf-example.yaml +++ b/opentdf-example.yaml @@ -2,6 +2,17 @@ logger: level: debug type: text output: stdout +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles + # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. + # audited_entity_jwt_claims: + # - sub # DB and Server configurations are defaulted for local development db: host: opentdfdb diff --git a/opentdf-kas-mode.yaml b/opentdf-kas-mode.yaml index ebcfb6f0c2..bd7a9009c2 100644 --- a/opentdf-kas-mode.yaml +++ b/opentdf-kas-mode.yaml @@ -11,6 +11,17 @@ logger: level: debug type: text output: stderr +audit: + # Optional JWT claim mappings for audit enrichment. + # Paths use dot notation over the emitted audit log shape. + # jwt_claim_mappings: + # - claim: sub + # path: eventMetaData.requester.sub + # - claim: realm_access.roles + # path: eventMetaData.requester.roles + # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. + # audited_entity_jwt_claims: + # - sub security: unsafe: # Increase only when diagnosing clock drift issues; default is 1m diff --git a/service/internal/dotnotation/dotnotation.go b/service/internal/dotnotation/dotnotation.go new file mode 100644 index 0000000000..7a021fd834 --- /dev/null +++ b/service/internal/dotnotation/dotnotation.go @@ -0,0 +1,78 @@ +package dotnotation + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +// Get retrieves a value from a nested map using dot notation keys. +func Get(m map[string]any, key string) any { + keys := strings.Split(key, ".") + for i, k := range keys { + if i == len(keys)-1 { + return m[k] + } + if m[k] == nil { + return nil + } + var ok bool + m, ok = toMap(m[k]) + if !ok { + return nil + } + } + return nil +} + +// Set stores a value in a nested map using dot notation keys, creating +// intermediate maps as needed. +func Set(m map[string]any, key string, value any) error { + if key == "" { + return errors.New("empty path") + } + + keys := strings.Split(key, ".") + current := m + for i, k := range keys[:len(keys)-1] { + next, exists := current[k] + if !exists || next == nil { + child := map[string]any{} + current[k] = child + current = child + continue + } + + child, ok := toMap(next) + if !ok { + return fmt.Errorf("path collision at %s", strings.Join(keys[:i+1], ".")) + } + current[k] = child + current = child + } + + current[keys[len(keys)-1]] = value + return nil +} + +func toMap(value any) (map[string]any, bool) { + if value == nil { + return nil, false + } + if typed, ok := value.(map[string]any); ok { + return typed, true + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Map || rv.Type().Key().Kind() != reflect.String { + return nil, false + } + + out := make(map[string]any, rv.Len()) + iter := rv.MapRange() + for iter.Next() { + out[iter.Key().String()] = iter.Value().Interface() + } + return out, true +} diff --git a/service/internal/dotnotation/dotnotation_test.go b/service/internal/dotnotation/dotnotation_test.go new file mode 100644 index 0000000000..824abf4320 --- /dev/null +++ b/service/internal/dotnotation/dotnotation_test.go @@ -0,0 +1,73 @@ +package dotnotation + +import "testing" + +func TestGet(t *testing.T) { + tests := []struct { + name string + input map[string]any + key string + expected any + }{ + {name: "valid key", input: map[string]any{"a": map[string]any{"b": 1}}, key: "a.b", expected: 1}, + {name: "non-existent key", input: map[string]any{"a": map[string]any{"b": 1}}, key: "a.c", expected: nil}, + {name: "nested map", input: map[string]any{"a": map[string]any{"b": map[string]any{"c": 2}}}, key: "a.b.c", expected: 2}, + {name: "map string string", input: map[string]any{"a": map[string]string{"b": "value"}}, key: "a.b", expected: "value"}, + {name: "invalid key type", input: map[string]any{"a": 1}, key: "a.b", expected: nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Get(tt.input, tt.key) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestSet(t *testing.T) { + t.Run("creates nested maps", func(t *testing.T) { + input := map[string]any{} + + err := Set(input, "eventMetaData.requester.sub", "test-user") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := Get(input, "eventMetaData.requester.sub"); got != "test-user" { + t.Fatalf("expected nested value, got %v", got) + } + }) + + t.Run("converts string keyed maps", func(t *testing.T) { + input := map[string]any{ + "eventMetaData": map[string]string{ + "existing": "value", + }, + } + + err := Set(input, "eventMetaData.requester.sub", "test-user") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got := Get(input, "eventMetaData.existing"); got != "value" { + t.Fatalf("expected existing value, got %v", got) + } + if got := Get(input, "eventMetaData.requester.sub"); got != "test-user" { + t.Fatalf("expected nested value, got %v", got) + } + }) + + t.Run("fails on non-map collision", func(t *testing.T) { + input := map[string]any{ + "eventMetaData": "not-a-map", + } + + err := Set(input, "eventMetaData.requester.sub", "test-user") + if err == nil { + t.Fatal("expected collision error") + } + }) +} diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 8d16b784f7..349c911fd2 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -542,7 +542,7 @@ func newConnectRPC(c Config, authInt connect.Interceptor, ints []connect.Interce // Add protovalidate interceptor validationInterceptor := validate.NewInterceptor() - interceptors = append(interceptors, connect.WithInterceptors(validationInterceptor, audit.ContextServerInterceptor(logger.Logger))) + interceptors = append(interceptors, connect.WithInterceptors(validationInterceptor, audit.ContextServerInterceptor(logger.Audit))) // Add any additional interceptors provided programmatically AFTER the default ones, so they have access needed context if len(ints) > 0 { diff --git a/service/logger/audit/config.go b/service/logger/audit/config.go new file mode 100644 index 0000000000..3b400b6172 --- /dev/null +++ b/service/logger/audit/config.go @@ -0,0 +1,19 @@ +package audit + +// JWTClaimMapping maps a JWT claim to a destination in the emitted audit log. +// The destination path uses dot notation over the audit log output shape, e.g. +// `eventMetaData.requester.sub`. +type JWTClaimMapping struct { + Claim string `mapstructure:"claim" json:"claim"` + Path string `mapstructure:"path" json:"path"` +} + +// Config contains platform-wide audit configuration. +type Config struct { + // AuditedEntityJWTClaims is a legacy shorthand that writes stringified claim + // values to `eventMetaData.entityMetadata.`. + AuditedEntityJWTClaims []string `mapstructure:"audited_entity_jwt_claims" json:"audited_entity_jwt_claims"` + // JWTClaimMappings writes JWT claims to configured destinations in the audit + // log output, preserving the original JSON value types where possible. + JWTClaimMappings []JWTClaimMapping `mapstructure:"jwt_claim_mappings" json:"jwt_claim_mappings"` +} diff --git a/service/logger/audit/contextServerInterceptor.go b/service/logger/audit/contextServerInterceptor.go index bbf69cf347..13a7032347 100644 --- a/service/logger/audit/contextServerInterceptor.go +++ b/service/logger/audit/contextServerInterceptor.go @@ -2,7 +2,6 @@ package audit import ( "context" - "log/slog" "net/http" "connectrpc.com/connect" @@ -13,7 +12,7 @@ import ( // ContextServerInterceptor allows audit events to track request state. // This is required for audit logging. -func ContextServerInterceptor(logger *slog.Logger) connect.UnaryInterceptorFunc { +func ContextServerInterceptor(logger *Logger) connect.UnaryInterceptorFunc { interceptor := func(next connect.UnaryFunc) connect.UnaryFunc { return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { // Get metadata from the context diff --git a/service/logger/audit/enrichment.go b/service/logger/audit/enrichment.go new file mode 100644 index 0000000000..d2d0490990 --- /dev/null +++ b/service/logger/audit/enrichment.go @@ -0,0 +1,195 @@ +package audit + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "reflect" + + internaldotnotation "github.com/opentdf/platform/service/internal/dotnotation" + ctxAuth "github.com/opentdf/platform/service/pkg/auth" +) + +var reservedClaimDestinationPaths = map[string]struct{}{ + "requestID": {}, + "timestamp": {}, + "action.type": {}, + "action.result": {}, + "object.type": {}, + "actor.id": {}, + "clientInfo.platform": {}, +} + +func (a *Logger) buildLogEntry(ctx context.Context, event *EventObject) map[string]any { + entry := event.logMap() + a.applyJWTClaimEnrichment(ctx, entry) + return entry +} + +func (a *Logger) applyJWTClaimEnrichment(ctx context.Context, entry map[string]any) { + if len(a.config.AuditedEntityJWTClaims) == 0 && len(a.config.JWTClaimMappings) == 0 { + return + } + + token := ctxAuth.GetAccessTokenFromContext(ctx, a.logger) + if token == nil { + return + } + + claimsMap, err := token.AsMap(ctx) + if err != nil { + a.logger.ErrorContext(ctx, "failed to read JWT claims for audit enrichment", slog.Any("error", err)) + return + } + + a.applyLegacyAuditedEntityClaims(ctx, entry, claimsMap) + a.applyMappedJWTClaims(ctx, entry, claimsMap) +} + +func (a *Logger) applyLegacyAuditedEntityClaims(ctx context.Context, entry map[string]any, claimsMap map[string]any) { + if len(a.config.AuditedEntityJWTClaims) == 0 { + return + } + + entityMetadata, err := getOrCreateMapAtPath(entry, "eventMetaData.entityMetadata") + if err != nil { + a.logger.ErrorContext(ctx, "failed to create legacy audit entity metadata destination", slog.Any("error", err)) + return + } + + for _, claim := range a.config.AuditedEntityJWTClaims { + if claim == "" { + continue + } + + value := internaldotnotation.Get(claimsMap, claim) + if value == nil { + continue + } + + entityMetadata[claim] = stringifyJWTClaimValue(value) + } +} + +func (a *Logger) applyMappedJWTClaims(ctx context.Context, entry map[string]any, claimsMap map[string]any) { + for _, mapping := range a.config.JWTClaimMappings { + if mapping.Claim == "" || mapping.Path == "" { + continue + } + if _, reserved := reservedClaimDestinationPaths[mapping.Path]; reserved { + a.logger.ErrorContext(ctx, + "refusing to write JWT claim to reserved audit path", + slog.String("claim", mapping.Claim), + slog.String("path", mapping.Path), + ) + continue + } + + value := internaldotnotation.Get(claimsMap, mapping.Claim) + if value == nil { + continue + } + + if err := internaldotnotation.Set(entry, mapping.Path, normalizeAuditValue(value)); err != nil { + a.logger.ErrorContext(ctx, + "failed to apply JWT claim mapping to audit log", + slog.String("claim", mapping.Claim), + slog.String("path", mapping.Path), + slog.Any("error", err), + ) + } + } +} + +func getOrCreateMapAtPath(root map[string]any, path string) (map[string]any, error) { + if path == "" { + return root, nil + } + + current := internaldotnotation.Get(root, path) + if current == nil { + if err := internaldotnotation.Set(root, path, map[string]any{}); err != nil { + return nil, err + } + current = internaldotnotation.Get(root, path) + } + + converted, ok := normalizeAuditValue(current).(map[string]any) + if !ok { + return nil, fmt.Errorf("path [%s] does not resolve to an object", path) + } + + if err := internaldotnotation.Set(root, path, converted); err != nil { + return nil, err + } + return converted, nil +} + +func normalizeAuditValue(value any) any { + switch typed := value.(type) { + case nil: + return nil + case map[string]any: + if typed == nil { + return nil + } + normalized := make(map[string]any, len(typed)) + for key, nested := range typed { + normalized[key] = normalizeAuditValue(nested) + } + return normalized + case []any: + if typed == nil { + return nil + } + normalized := make([]any, len(typed)) + for idx, nested := range typed { + normalized[idx] = normalizeAuditValue(nested) + } + return normalized + } + + rv := reflect.ValueOf(value) + //nolint:exhaustive // only composite kinds need normalization; scalars can pass through unchanged + switch rv.Kind() { + case reflect.Map: + if rv.IsNil() { + return nil + } + if rv.Type().Key().Kind() != reflect.String { + return value + } + normalized := make(map[string]any, rv.Len()) + iter := rv.MapRange() + for iter.Next() { + normalized[iter.Key().String()] = normalizeAuditValue(iter.Value().Interface()) + } + return normalized + case reflect.Slice, reflect.Array: + if rv.Kind() == reflect.Slice && rv.IsNil() { + return nil + } + normalized := make([]any, rv.Len()) + for idx := range normalized { + normalized[idx] = normalizeAuditValue(rv.Index(idx).Interface()) + } + return normalized + default: + return value + } +} + +func stringifyJWTClaimValue(value any) string { + switch typed := value.(type) { + case string: + return typed + default: + normalized := normalizeAuditValue(typed) + encoded, err := json.Marshal(normalized) + if err != nil { + return fmt.Sprintf("%v", normalized) + } + return string(encoded) + } +} diff --git a/service/logger/audit/logger.go b/service/logger/audit/logger.go index 56060c6f56..4ac7a70c07 100644 --- a/service/logger/audit/logger.go +++ b/service/logger/audit/logger.go @@ -34,6 +34,7 @@ var logLevelNames = map[slog.Leveler]string{ type Logger struct { logger *slog.Logger + config Config } // Used to support custom log levels showing up with custom labels as well @@ -61,10 +62,16 @@ func CreateAuditLogger(logger slog.Logger) *Logger { } } +// ApplyConfig stores the latest audit enrichment configuration. +func (a *Logger) ApplyConfig(cfg Config) { + a.config = cfg +} + func (a *Logger) With(key string, value string) *Logger { return &Logger{ //nolint:sloglint // custom logger should support key/value pairs in With attributes logger: a.logger.With(key, value), + config: a.config, } } @@ -81,7 +88,7 @@ func (tx *auditTransaction) addEvent(verb Verb, event *EventObject) { // logClose completes an audit transaction and emits all recorded events. // If success is false or err is not nil, events are logged as "cancelled" with the error attached. // Otherwise, events are logged with their originally recorded success/failure status. -func (tx *auditTransaction) logClose(ctx context.Context, logger *slog.Logger, success bool, err error) { +func (tx *auditTransaction) logClose(ctx context.Context, auditLogger *Logger, success bool, err error) { tx.mu.Lock() defer tx.mu.Unlock() for _, event := range tx.events { @@ -99,7 +106,7 @@ func (tx *auditTransaction) logClose(ctx context.Context, logger *slog.Logger, s } //nolint:sloglint // audit message is always just the verb - logger.Log(ctx, LevelAudit, string(event.verb), slog.Any("audit", *auditEvent)) + auditLogger.logger.Log(ctx, LevelAudit, string(event.verb), slog.Any("audit", auditLogger.buildLogEntry(ctx, auditEvent))) } } diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index c254f6dfe4..67c3284be9 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -11,9 +11,12 @@ import ( "time" "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/protocol/go/authorization" + ctxAuth "github.com/opentdf/platform/service/pkg/auth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" ) // Params @@ -92,20 +95,23 @@ func extractLogEntry(t *testing.T, logBuffer *bytes.Buffer) (logEntryStructure, } func doWithLogger(t *testing.T, testFunc func(ctx context.Context, l *Logger)) (ls logEntryStructure, lt time.Time) { //nolint:nonamedreturns // Required to rewrite on panics, right? + return doWithLoggerContext(createTestContext(t), t, testFunc) +} + +func doWithLoggerContext(ctx context.Context, t *testing.T, testFunc func(ctx context.Context, l *Logger)) (ls logEntryStructure, lt time.Time) { //nolint:nonamedreturns // Required to rewrite on panics, right? l, buf := createTestLogger() - ctx := createTestContext(t) tx, ok := ctx.Value(contextKey{}).(*auditTransaction) require.True(t, ok, "audit transaction missing from context") defer func() { if r := recover(); r != nil { if err, okerr := r.(error); okerr { - tx.logClose(ctx, l.logger, false, err) + tx.logClose(ctx, l, false, err) } else { - tx.logClose(ctx, l.logger, false, nil) + tx.logClose(ctx, l, false, nil) } } else { - tx.logClose(ctx, l.logger, true, nil) + tx.logClose(ctx, l, true, nil) } ls, lt = extractLogEntry(t, buf) }() @@ -115,6 +121,30 @@ func doWithLogger(t *testing.T, testFunc func(ctx context.Context, l *Logger)) ( return ls, lt } +func createTestJWTForAudit(t *testing.T) (jwt.Token, string) { + t.Helper() + + token, err := jwt.NewBuilder(). + Subject("jwt-user"). + Claim("realm_access", map[string]any{"roles": []string{"admin", "user"}}). + Claim("email_verified", true). + Build() + require.NoError(t, err) + + rawToken, err := jwt.Sign(token, jwt.WithInsecureNoSignature()) + require.NoError(t, err) + + return token, string(rawToken) +} + +func decodeAuditPayload(t *testing.T, payload json.RawMessage) map[string]any { + t.Helper() + + var decoded map[string]any + require.NoError(t, json.Unmarshal(payload, &decoded)) + return decoded +} + func TestAuditRewrapSuccess(t *testing.T) { logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { l.RewrapSuccess(ctx, rewrapParams) @@ -333,6 +363,77 @@ func TestPolicyCrudFailure(t *testing.T) { assert.JSONEq(t, expectedAuditLog, loggedMessage) } +func TestAuditJWTClaimMappingsApplyToPolicyAudit(t *testing.T) { + token, rawToken := createTestJWTForAudit(t) + ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) + + logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + {Claim: "realm_access.roles", Path: "eventMetaData.requester.roles"}, + {Claim: "email_verified", Path: "eventMetaData.requester.emailVerified"}, + }, + }) + l.PolicyCRUDSuccess(ctx, policyCRUDParams) + }) + + payload := decodeAuditPayload(t, logEntry.Audit) + eventMetaData, ok := payload["eventMetaData"].(map[string]any) + require.True(t, ok) + requester, ok := eventMetaData["requester"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", requester["sub"]) + assert.Equal(t, []any{"admin", "user"}, requester["roles"]) + assert.Equal(t, true, requester["emailVerified"]) +} + +func TestAuditJWTClaimMappingsUseMetadataFallback(t *testing.T) { + _, rawToken := createTestJWTForAudit(t) + ctx := metadata.NewIncomingContext(createTestContext(t), metadata.Pairs(ctxAuth.AccessTokenKey, rawToken)) + + logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + }, + }) + l.PolicyCRUDSuccess(ctx, policyCRUDParams) + }) + + payload := decodeAuditPayload(t, logEntry.Audit) + eventMetaData, ok := payload["eventMetaData"].(map[string]any) + require.True(t, ok) + requester, ok := eventMetaData["requester"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", requester["sub"]) +} + +func TestAuditLegacyAuditedEntityJWTClaimsApplyToPolicyAudit(t *testing.T) { + token, rawToken := createTestJWTForAudit(t) + ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) + + logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + l.ApplyConfig(Config{ + AuditedEntityJWTClaims: []string{ + "sub", + "realm_access.roles", + "email_verified", + }, + }) + l.PolicyCRUDSuccess(ctx, policyCRUDParams) + }) + + payload := decodeAuditPayload(t, logEntry.Audit) + eventMetaData, ok := payload["eventMetaData"].(map[string]any) + require.True(t, ok) + entityMetadata, ok := eventMetaData["entityMetadata"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", entityMetadata["sub"]) + assert.Equal(t, `["admin","user"]`, entityMetadata["realm_access.roles"]) + assert.Equal(t, "true", entityMetadata["email_verified"]) +} + func TestDeferredRewrapSuccess(t *testing.T) { logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { l.RewrapSuccess(ctx, rewrapParams) diff --git a/service/logger/audit/utils.go b/service/logger/audit/utils.go index 2d1bf12e88..e87f18ebb4 100644 --- a/service/logger/audit/utils.go +++ b/service/logger/audit/utils.go @@ -41,6 +41,39 @@ func (e EventObject) LogValue() slog.Value { slog.String("timestamp", e.Timestamp)) } +func (e EventObject) logMap() map[string]any { + return map[string]any{ + "object": map[string]any{ + "type": e.Object.Type.String(), + "id": e.Object.ID, + "name": e.Object.Name, + "attributes": map[string]any{ + "assertions": e.Object.Attributes.Assertions, + "attrs": e.Object.Attributes.Attrs, + "permissions": e.Object.Attributes.Permissions, + }, + }, + "action": map[string]any{ + "type": e.Action.Type.String(), + "result": e.Action.Result.String(), + }, + "actor": map[string]any{ + "id": e.Actor.ID, + "attributes": e.Actor.Attributes, + }, + "eventMetaData": normalizeAuditValue(e.EventMetaData), + "clientInfo": map[string]any{ + "userAgent": e.ClientInfo.UserAgent, + "platform": e.ClientInfo.Platform, + "requestIP": e.ClientInfo.RequestIP, + }, + "original": normalizeAuditValue(e.Original), + "updated": normalizeAuditValue(e.Updated), + "requestID": e.RequestID.String(), + "timestamp": e.Timestamp, + } +} + // event.object type auditEventObject struct { Type ObjectType `json:"type"` diff --git a/service/pkg/auth/context_auth.go b/service/pkg/auth/context_auth.go index 91005e42b6..904ba63e8c 100644 --- a/service/pkg/auth/context_auth.go +++ b/service/pkg/auth/context_auth.go @@ -3,10 +3,10 @@ package auth import ( "context" "errors" + "strings" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/opentdf/platform/service/logger" "google.golang.org/grpc/metadata" ) @@ -30,6 +30,10 @@ type authContext struct { rawToken string } +type contextErrorLogger interface { + ErrorContext(context.Context, string, ...any) +} + func ContextWithAuthNInfo(ctx context.Context, key jwk.Key, accessToken jwt.Token, raw string) context.Context { return context.WithValue(ctx, authnContextKey, &authContext{ key, @@ -38,7 +42,7 @@ func ContextWithAuthNInfo(ctx context.Context, key jwk.Key, accessToken jwt.Toke }) } -func getContextDetails(ctx context.Context, l *logger.Logger) *authContext { +func getContextDetails(ctx context.Context, l contextErrorLogger) *authContext { key := ctx.Value(authnContextKey) if key == nil { return nil @@ -47,29 +51,56 @@ func getContextDetails(ctx context.Context, l *logger.Logger) *authContext { return c } - // We should probably return an error here? - l.ErrorContext(ctx, "invalid authContext") + if l != nil { + l.ErrorContext(ctx, "invalid authContext") + } return nil } -func GetJWKFromContext(ctx context.Context, l *logger.Logger) jwk.Key { +func GetJWKFromContext(ctx context.Context, l contextErrorLogger) jwk.Key { if c := getContextDetails(ctx, l); c != nil { return c.key } return nil } -func GetAccessTokenFromContext(ctx context.Context, l *logger.Logger) jwt.Token { +func GetAccessTokenFromContext(ctx context.Context, l contextErrorLogger) jwt.Token { if c := getContextDetails(ctx, l); c != nil { - return c.accessToken + if c.accessToken != nil { + return c.accessToken + } } - return nil + + rawToken := GetRawAccessTokenFromContext(ctx, l) + if rawToken == "" { + return nil + } + + parsed, err := jwt.Parse([]byte(rawToken), jwt.WithVerify(false), jwt.WithValidate(false)) + if err != nil { + if l != nil { + l.ErrorContext(ctx, "failed to parse access token from context", "error", err) + } + return nil + } + + return parsed } -func GetRawAccessTokenFromContext(ctx context.Context, l *logger.Logger) string { +func GetRawAccessTokenFromContext(ctx context.Context, l contextErrorLogger) string { if c := getContextDetails(ctx, l); c != nil { - return c.rawToken + if c.rawToken != "" { + return c.rawToken + } + } + + if rawToken := getRawAccessTokenFromMetadata(ctx, true); rawToken != "" { + return rawToken + } + if rawToken := getRawAccessTokenFromMetadata(ctx, false); rawToken != "" { + return rawToken } + return "" } @@ -77,7 +108,7 @@ func GetRawAccessTokenFromContext(ctx context.Context, l *logger.Logger) string // // Adding the authn info to gRPC metadata propagates it across services rather than strictly // in-process within Go alone -func EnrichIncomingContextMetadataWithAuthn(ctx context.Context, l *logger.Logger, clientID string) context.Context { +func EnrichIncomingContextMetadataWithAuthn(ctx context.Context, l contextErrorLogger, clientID string) context.Context { rawToken := GetRawAccessTokenFromContext(ctx, l) md, ok := metadata.FromIncomingContext(ctx) @@ -97,6 +128,45 @@ func EnrichIncomingContextMetadataWithAuthn(ctx context.Context, l *logger.Logge return metadata.NewIncomingContext(ctx, md) } +func getRawAccessTokenFromMetadata(ctx context.Context, incoming bool) string { + var ( + md metadata.MD + ok bool + ) + if incoming { + md, ok = metadata.FromIncomingContext(ctx) + } else { + md, ok = metadata.FromOutgoingContext(ctx) + } + if !ok { + return "" + } + + if accessTokens := md.Get(AccessTokenKey); len(accessTokens) > 0 && accessTokens[0] != "" { + return accessTokens[0] + } + + authHeaders := md.Get("authorization") + if len(authHeaders) == 0 { + authHeaders = md.Get("Authorization") + } + if len(authHeaders) > 0 { + return trimAuthorizationHeader(authHeaders[0]) + } + return "" +} + +func trimAuthorizationHeader(header string) string { + switch { + case strings.HasPrefix(header, "Bearer "): + return strings.TrimPrefix(header, "Bearer ") + case strings.HasPrefix(header, "DPoP "): + return strings.TrimPrefix(header, "DPoP ") + default: + return header + } +} + // GetClientIDFromContext retrieves the client ID from the metadata in the context func GetClientIDFromContext(ctx context.Context, incoming bool) (string, error) { var ( diff --git a/service/pkg/auth/context_auth_test.go b/service/pkg/auth/context_auth_test.go index 63354eb7f4..e09b193b3b 100644 --- a/service/pkg/auth/context_auth_test.go +++ b/service/pkg/auth/context_auth_test.go @@ -6,7 +6,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/opentdf/platform/service/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" @@ -38,7 +37,7 @@ func TestGetJWKFromContext(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), mockJWK, nil, "") // Retrieve the JWK and assert - retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger()) + retrievedJWK := GetJWKFromContext(ctx, nil) assert.NotNil(t, retrievedJWK, "JWK should not be nil") assert.Equal(t, mockJWK, retrievedJWK, "Retrieved JWK should match the mock JWK") } @@ -49,7 +48,7 @@ func TestGetAccessTokenFromContext(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), nil, mockJWT, "") // Retrieve the JWT and assert - retrievedJWT := GetAccessTokenFromContext(ctx, logger.CreateTestLogger()) + retrievedJWT := GetAccessTokenFromContext(ctx, nil) assert.NotNil(t, retrievedJWT, "Access token should not be nil") assert.Equal(t, mockJWT, retrievedJWT, "Retrieved JWT should match the mock JWT") } @@ -60,26 +59,54 @@ func TestGetRawAccessTokenFromContext(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), nil, nil, rawToken) // Retrieve the raw token and assert - retrievedRawToken := GetRawAccessTokenFromContext(ctx, logger.CreateTestLogger()) + retrievedRawToken := GetRawAccessTokenFromContext(ctx, nil) assert.Equal(t, rawToken, retrievedRawToken, "Retrieved raw token should match the mock raw token") } +func TestGetRawAccessTokenFromContextFallsBackToMetadata(t *testing.T) { + t.Run("incoming access token metadata", func(t *testing.T) { + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, "incoming-token")) + retrievedRawToken := GetRawAccessTokenFromContext(ctx, nil) + assert.Equal(t, "incoming-token", retrievedRawToken) + }) + + t.Run("outgoing authorization metadata", func(t *testing.T) { + ctx := metadata.NewOutgoingContext(t.Context(), metadata.Pairs("Authorization", "Bearer outgoing-token")) + retrievedRawToken := GetRawAccessTokenFromContext(ctx, nil) + assert.Equal(t, "outgoing-token", retrievedRawToken) + }) +} + +func TestGetAccessTokenFromContextFallsBackToMetadata(t *testing.T) { + mockJWT, err := jwt.NewBuilder(). + Subject("metadata-user"). + Claim("roles", []string{"admin"}). + Build() + require.NoError(t, err) + + rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) + + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, string(rawToken))) + retrievedJWT := GetAccessTokenFromContext(ctx, nil) + require.NotNil(t, retrievedJWT) + assert.Equal(t, "metadata-user", retrievedJWT.Subject()) +} + func TestGetContextDetailsInvalidType(t *testing.T) { // Create a context with an invalid type ctx := context.WithValue(t.Context(), authnContextKey, "invalidType") // Assert that GetJWKFromContext handles the invalid type correctly - retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger()) + retrievedJWK := GetJWKFromContext(ctx, nil) assert.Nil(t, retrievedJWK, "JWK should be nil when context value is invalid") } func TestEnrichIncomingContextMetadataWithAuthn(t *testing.T) { mockClientID := "test-client-id" - l := logger.CreateTestLogger() - t.Run("should add access token and client id to metadata", func(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), nil, nil, "raw-token-string") - enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, l, mockClientID) + enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, nil, mockClientID) md, ok := metadata.FromIncomingContext(enrichedCtx) require.True(t, ok) @@ -95,7 +122,7 @@ func TestEnrichIncomingContextMetadataWithAuthn(t *testing.T) { t.Run("should not set client id if empty", func(t *testing.T) { ctx := ContextWithAuthNInfo(t.Context(), nil, nil, "raw-token-string") - enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, l, "") + enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, nil, "") md, ok := metadata.FromIncomingContext(enrichedCtx) require.True(t, ok) @@ -108,7 +135,7 @@ func TestEnrichIncomingContextMetadataWithAuthn(t *testing.T) { originalMD := metadata.New(map[string]string{"original-key": "original-value"}) ctx := metadata.NewIncomingContext(t.Context(), originalMD) ctx = ContextWithAuthNInfo(ctx, nil, nil, "raw-token-string") - enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, l, mockClientID) + enrichedCtx := EnrichIncomingContextMetadataWithAuthn(ctx, nil, mockClientID) md, ok := metadata.FromIncomingContext(enrichedCtx) require.True(t, ok) diff --git a/service/pkg/config/config.go b/service/pkg/config/config.go index aac79e7fbc..acbc1f3b28 100644 --- a/service/pkg/config/config.go +++ b/service/pkg/config/config.go @@ -10,6 +10,7 @@ import ( "github.com/go-playground/validator/v10" "github.com/opentdf/platform/service/internal/server" "github.com/opentdf/platform/service/logger" + auditcfg "github.com/opentdf/platform/service/logger/audit" "github.com/opentdf/platform/service/pkg/db" "github.com/spf13/viper" ) @@ -76,6 +77,9 @@ type Config struct { // Logger represents the configuration settings for the logger. Logger logger.Config `mapstructure:"logger" json:"logger"` + // Audit represents the configuration settings for audit enrichment. + Audit auditcfg.Config `mapstructure:"audit" json:"audit"` + // Mode specifies which services to run. // By default, it runs all services. Mode []string `mapstructure:"mode" json:"mode" default:"[\"all\"]"` @@ -136,6 +140,7 @@ var ( func (c *Config) LogValue() slog.Value { return slog.GroupValue( slog.Bool("dev_mode", c.DevMode), + slog.Any("audit", c.Audit), slog.Any("db", c.DB), slog.Any("logger", c.Logger), slog.Any("mode", c.Mode), diff --git a/service/pkg/config/config_test.go b/service/pkg/config/config_test.go index 4ce51adc42..105d7c6a5a 100644 --- a/service/pkg/config/config_test.go +++ b/service/pkg/config/config_test.go @@ -6,6 +6,7 @@ import ( "os" "testing" + auditcfg "github.com/opentdf/platform/service/logger/audit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -388,11 +389,22 @@ server: port: 9090 logger: level: warn +audit: + audited_entity_jwt_claims: + - sub + - realm_access.roles + jwt_claim_mappings: + - claim: email_verified + path: eventMetaData.requester.emailVerified `, asserts: func(t *testing.T, cfg *Config) { // Values from file assert.Equal(t, 9090, cfg.Server.Port) assert.Equal(t, "warn", cfg.Logger.Level) + assert.Equal(t, []string{"sub", "realm_access.roles"}, cfg.Audit.AuditedEntityJWTClaims) + assert.Equal(t, []auditcfg.JWTClaimMapping{ + {Claim: "email_verified", Path: "eventMetaData.requester.emailVerified"}, + }, cfg.Audit.JWTClaimMappings) // Value from defaults assert.Equal(t, []string{"all"}, cfg.Mode) }, diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index 5af961893a..7a2d3df3a2 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -171,6 +171,7 @@ func startServices(ctx context.Context, params startServicesParams) (func(), err newSvcLogger, err := logging.NewLogger(newLoggerConfig) // only assign if logger successfully created if err == nil { + newSvcLogger.Audit.ApplyConfig(cfg.Audit) svcLogger = newSvcLogger.With("namespace", ns) } } diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index f3513a090d..a9a39e54dd 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -112,6 +112,7 @@ func Start(f ...StartOptions) error { if err != nil { return fmt.Errorf("could not start logger: %w", err) } + logger.Audit.ApplyConfig(cfg.Audit) // Set default for places we can't pass the logger slog.SetDefault(logger.Logger) From 81f19d54637f1960e403863affd89baadbbec2ac Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 12:55:21 -0700 Subject: [PATCH 02/20] refactor(audit): drop legacy claim shorthand Signed-off-by: jakedoublev --- opentdf-core-mode.yaml | 3 -- opentdf-dev.yaml | 3 -- opentdf-ers-mode.yaml | 3 -- opentdf-ers-test.yaml | 3 -- opentdf-example.yaml | 3 -- opentdf-kas-mode.yaml | 3 -- service/logger/audit/config.go | 3 -- service/logger/audit/enrichment.go | 68 +---------------------------- service/logger/audit/logger_test.go | 14 +++--- service/pkg/config/config_test.go | 4 -- 10 files changed, 8 insertions(+), 99 deletions(-) diff --git a/opentdf-core-mode.yaml b/opentdf-core-mode.yaml index d6a4113313..a3177d4a63 100644 --- a/opentdf-core-mode.yaml +++ b/opentdf-core-mode.yaml @@ -19,9 +19,6 @@ audit: # path: eventMetaData.requester.sub # - claim: realm_access.roles # path: eventMetaData.requester.roles - # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. - # audited_entity_jwt_claims: - # - sub services: kas: # DB and Server configurations are defaulted for local development diff --git a/opentdf-dev.yaml b/opentdf-dev.yaml index 5304f54f08..dd2050e30f 100644 --- a/opentdf-dev.yaml +++ b/opentdf-dev.yaml @@ -10,9 +10,6 @@ audit: # path: eventMetaData.requester.sub # - claim: realm_access.roles # path: eventMetaData.requester.roles - # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. - # audited_entity_jwt_claims: - # - sub # DB and Server configurations are defaulted for local development # db: # host: localhost diff --git a/opentdf-ers-mode.yaml b/opentdf-ers-mode.yaml index 662d10a298..aa38c8d23e 100644 --- a/opentdf-ers-mode.yaml +++ b/opentdf-ers-mode.yaml @@ -13,9 +13,6 @@ audit: # path: eventMetaData.requester.sub # - claim: realm_access.roles # path: eventMetaData.requester.roles - # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. - # audited_entity_jwt_claims: - # - sub services: kas: entityresolution: diff --git a/opentdf-ers-test.yaml b/opentdf-ers-test.yaml index fdbfaa3d16..e46eb43f06 100644 --- a/opentdf-ers-test.yaml +++ b/opentdf-ers-test.yaml @@ -20,9 +20,6 @@ audit: # path: eventMetaData.requester.sub # - claim: realm_access.roles # path: eventMetaData.requester.roles - # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. - # audited_entity_jwt_claims: - # - sub crypto: type: standard diff --git a/opentdf-example.yaml b/opentdf-example.yaml index a3b2b3f6fb..a397f20ff3 100644 --- a/opentdf-example.yaml +++ b/opentdf-example.yaml @@ -10,9 +10,6 @@ audit: # path: eventMetaData.requester.sub # - claim: realm_access.roles # path: eventMetaData.requester.roles - # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. - # audited_entity_jwt_claims: - # - sub # DB and Server configurations are defaulted for local development db: host: opentdfdb diff --git a/opentdf-kas-mode.yaml b/opentdf-kas-mode.yaml index bd7a9009c2..8359710fdf 100644 --- a/opentdf-kas-mode.yaml +++ b/opentdf-kas-mode.yaml @@ -19,9 +19,6 @@ audit: # path: eventMetaData.requester.sub # - claim: realm_access.roles # path: eventMetaData.requester.roles - # Legacy shorthand: writes stringified values to eventMetaData.entityMetadata.. - # audited_entity_jwt_claims: - # - sub security: unsafe: # Increase only when diagnosing clock drift issues; default is 1m diff --git a/service/logger/audit/config.go b/service/logger/audit/config.go index 3b400b6172..1dc8761abc 100644 --- a/service/logger/audit/config.go +++ b/service/logger/audit/config.go @@ -10,9 +10,6 @@ type JWTClaimMapping struct { // Config contains platform-wide audit configuration. type Config struct { - // AuditedEntityJWTClaims is a legacy shorthand that writes stringified claim - // values to `eventMetaData.entityMetadata.`. - AuditedEntityJWTClaims []string `mapstructure:"audited_entity_jwt_claims" json:"audited_entity_jwt_claims"` // JWTClaimMappings writes JWT claims to configured destinations in the audit // log output, preserving the original JSON value types where possible. JWTClaimMappings []JWTClaimMapping `mapstructure:"jwt_claim_mappings" json:"jwt_claim_mappings"` diff --git a/service/logger/audit/enrichment.go b/service/logger/audit/enrichment.go index d2d0490990..0ae0776c4a 100644 --- a/service/logger/audit/enrichment.go +++ b/service/logger/audit/enrichment.go @@ -2,8 +2,6 @@ package audit import ( "context" - "encoding/json" - "fmt" "log/slog" "reflect" @@ -28,7 +26,7 @@ func (a *Logger) buildLogEntry(ctx context.Context, event *EventObject) map[stri } func (a *Logger) applyJWTClaimEnrichment(ctx context.Context, entry map[string]any) { - if len(a.config.AuditedEntityJWTClaims) == 0 && len(a.config.JWTClaimMappings) == 0 { + if len(a.config.JWTClaimMappings) == 0 { return } @@ -43,35 +41,9 @@ func (a *Logger) applyJWTClaimEnrichment(ctx context.Context, entry map[string]a return } - a.applyLegacyAuditedEntityClaims(ctx, entry, claimsMap) a.applyMappedJWTClaims(ctx, entry, claimsMap) } -func (a *Logger) applyLegacyAuditedEntityClaims(ctx context.Context, entry map[string]any, claimsMap map[string]any) { - if len(a.config.AuditedEntityJWTClaims) == 0 { - return - } - - entityMetadata, err := getOrCreateMapAtPath(entry, "eventMetaData.entityMetadata") - if err != nil { - a.logger.ErrorContext(ctx, "failed to create legacy audit entity metadata destination", slog.Any("error", err)) - return - } - - for _, claim := range a.config.AuditedEntityJWTClaims { - if claim == "" { - continue - } - - value := internaldotnotation.Get(claimsMap, claim) - if value == nil { - continue - } - - entityMetadata[claim] = stringifyJWTClaimValue(value) - } -} - func (a *Logger) applyMappedJWTClaims(ctx context.Context, entry map[string]any, claimsMap map[string]any) { for _, mapping := range a.config.JWTClaimMappings { if mapping.Claim == "" || mapping.Path == "" { @@ -102,30 +74,6 @@ func (a *Logger) applyMappedJWTClaims(ctx context.Context, entry map[string]any, } } -func getOrCreateMapAtPath(root map[string]any, path string) (map[string]any, error) { - if path == "" { - return root, nil - } - - current := internaldotnotation.Get(root, path) - if current == nil { - if err := internaldotnotation.Set(root, path, map[string]any{}); err != nil { - return nil, err - } - current = internaldotnotation.Get(root, path) - } - - converted, ok := normalizeAuditValue(current).(map[string]any) - if !ok { - return nil, fmt.Errorf("path [%s] does not resolve to an object", path) - } - - if err := internaldotnotation.Set(root, path, converted); err != nil { - return nil, err - } - return converted, nil -} - func normalizeAuditValue(value any) any { switch typed := value.(type) { case nil: @@ -179,17 +127,3 @@ func normalizeAuditValue(value any) any { return value } } - -func stringifyJWTClaimValue(value any) string { - switch typed := value.(type) { - case string: - return typed - default: - normalized := normalizeAuditValue(typed) - encoded, err := json.Marshal(normalized) - if err != nil { - return fmt.Sprintf("%v", normalized) - } - return string(encoded) - } -} diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index 67c3284be9..00a3e2bb6d 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -409,16 +409,16 @@ func TestAuditJWTClaimMappingsUseMetadataFallback(t *testing.T) { assert.Equal(t, "jwt-user", requester["sub"]) } -func TestAuditLegacyAuditedEntityJWTClaimsApplyToPolicyAudit(t *testing.T) { +func TestAuditJWTClaimMappingsCanWriteToEntityMetadata(t *testing.T) { token, rawToken := createTestJWTForAudit(t) ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { l.ApplyConfig(Config{ - AuditedEntityJWTClaims: []string{ - "sub", - "realm_access.roles", - "email_verified", + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.entityMetadata.sub"}, + {Claim: "realm_access.roles", Path: "eventMetaData.entityMetadata.roles"}, + {Claim: "email_verified", Path: "eventMetaData.entityMetadata.emailVerified"}, }, }) l.PolicyCRUDSuccess(ctx, policyCRUDParams) @@ -430,8 +430,8 @@ func TestAuditLegacyAuditedEntityJWTClaimsApplyToPolicyAudit(t *testing.T) { entityMetadata, ok := eventMetaData["entityMetadata"].(map[string]any) require.True(t, ok) assert.Equal(t, "jwt-user", entityMetadata["sub"]) - assert.Equal(t, `["admin","user"]`, entityMetadata["realm_access.roles"]) - assert.Equal(t, "true", entityMetadata["email_verified"]) + assert.Equal(t, []any{"admin", "user"}, entityMetadata["roles"]) + assert.Equal(t, true, entityMetadata["emailVerified"]) } func TestDeferredRewrapSuccess(t *testing.T) { diff --git a/service/pkg/config/config_test.go b/service/pkg/config/config_test.go index 105d7c6a5a..6c7778d030 100644 --- a/service/pkg/config/config_test.go +++ b/service/pkg/config/config_test.go @@ -390,9 +390,6 @@ server: logger: level: warn audit: - audited_entity_jwt_claims: - - sub - - realm_access.roles jwt_claim_mappings: - claim: email_verified path: eventMetaData.requester.emailVerified @@ -401,7 +398,6 @@ audit: // Values from file assert.Equal(t, 9090, cfg.Server.Port) assert.Equal(t, "warn", cfg.Logger.Level) - assert.Equal(t, []string{"sub", "realm_access.roles"}, cfg.Audit.AuditedEntityJWTClaims) assert.Equal(t, []auditcfg.JWTClaimMapping{ {Claim: "email_verified", Path: "eventMetaData.requester.emailVerified"}, }, cfg.Audit.JWTClaimMappings) From 903912a4b2cbb52f6cfbf6158e1480f8abdbc235 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 13:15:05 -0700 Subject: [PATCH 03/20] fix(audit): fail fast on invalid claim mappings Signed-off-by: jakedoublev --- service/logger/audit/config.go | 17 +++ service/logger/audit/enrichment.go | 18 --- service/logger/audit/logger.go | 8 +- service/logger/audit/logger_test.go | 49 ++++++- service/logger/audit/schema.go | 191 ++++++++++++++++++++++++++++ service/logger/audit/schema_test.go | 40 ++++++ service/logger/audit/utils.go | 40 +++--- service/pkg/config/config.go | 3 + service/pkg/config/config_test.go | 24 ++++ service/pkg/server/services.go | 33 +++-- service/pkg/server/start.go | 4 +- 11 files changed, 370 insertions(+), 57 deletions(-) create mode 100644 service/logger/audit/schema.go create mode 100644 service/logger/audit/schema_test.go diff --git a/service/logger/audit/config.go b/service/logger/audit/config.go index 1dc8761abc..2c547620b7 100644 --- a/service/logger/audit/config.go +++ b/service/logger/audit/config.go @@ -1,5 +1,7 @@ package audit +import "fmt" + // JWTClaimMapping maps a JWT claim to a destination in the emitted audit log. // The destination path uses dot notation over the audit log output shape, e.g. // `eventMetaData.requester.sub`. @@ -14,3 +16,18 @@ type Config struct { // log output, preserving the original JSON value types where possible. JWTClaimMappings []JWTClaimMapping `mapstructure:"jwt_claim_mappings" json:"jwt_claim_mappings"` } + +func (c Config) Validate() error { + for idx, mapping := range c.JWTClaimMappings { + if mapping.Claim == "" { + return fmt.Errorf("jwt_claim_mappings[%d].claim is required", idx) + } + if mapping.Path == "" { + return fmt.Errorf("jwt_claim_mappings[%d].path is required", idx) + } + if err := validateClaimDestinationPath(mapping.Path); err != nil { + return fmt.Errorf("jwt_claim_mappings[%d].path: %w", idx, err) + } + } + return nil +} diff --git a/service/logger/audit/enrichment.go b/service/logger/audit/enrichment.go index 0ae0776c4a..260be1a7bd 100644 --- a/service/logger/audit/enrichment.go +++ b/service/logger/audit/enrichment.go @@ -9,16 +9,6 @@ import ( ctxAuth "github.com/opentdf/platform/service/pkg/auth" ) -var reservedClaimDestinationPaths = map[string]struct{}{ - "requestID": {}, - "timestamp": {}, - "action.type": {}, - "action.result": {}, - "object.type": {}, - "actor.id": {}, - "clientInfo.platform": {}, -} - func (a *Logger) buildLogEntry(ctx context.Context, event *EventObject) map[string]any { entry := event.logMap() a.applyJWTClaimEnrichment(ctx, entry) @@ -49,14 +39,6 @@ func (a *Logger) applyMappedJWTClaims(ctx context.Context, entry map[string]any, if mapping.Claim == "" || mapping.Path == "" { continue } - if _, reserved := reservedClaimDestinationPaths[mapping.Path]; reserved { - a.logger.ErrorContext(ctx, - "refusing to write JWT claim to reserved audit path", - slog.String("claim", mapping.Claim), - slog.String("path", mapping.Path), - ) - continue - } value := internaldotnotation.Get(claimsMap, mapping.Claim) if value == nil { diff --git a/service/logger/audit/logger.go b/service/logger/audit/logger.go index 4ac7a70c07..538882923e 100644 --- a/service/logger/audit/logger.go +++ b/service/logger/audit/logger.go @@ -62,9 +62,13 @@ func CreateAuditLogger(logger slog.Logger) *Logger { } } -// ApplyConfig stores the latest audit enrichment configuration. -func (a *Logger) ApplyConfig(cfg Config) { +// ApplyConfig validates and stores the latest audit enrichment configuration. +func (a *Logger) ApplyConfig(cfg Config) error { + if err := cfg.Validate(); err != nil { + return err + } a.config = cfg + return nil } func (a *Logger) With(key string, value string) *Logger { diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index 00a3e2bb6d..c818a54a44 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -368,13 +368,13 @@ func TestAuditJWTClaimMappingsApplyToPolicyAudit(t *testing.T) { ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { - l.ApplyConfig(Config{ + require.NoError(t, l.ApplyConfig(Config{ JWTClaimMappings: []JWTClaimMapping{ {Claim: "sub", Path: "eventMetaData.requester.sub"}, {Claim: "realm_access.roles", Path: "eventMetaData.requester.roles"}, {Claim: "email_verified", Path: "eventMetaData.requester.emailVerified"}, }, - }) + })) l.PolicyCRUDSuccess(ctx, policyCRUDParams) }) @@ -393,11 +393,11 @@ func TestAuditJWTClaimMappingsUseMetadataFallback(t *testing.T) { ctx := metadata.NewIncomingContext(createTestContext(t), metadata.Pairs(ctxAuth.AccessTokenKey, rawToken)) logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { - l.ApplyConfig(Config{ + require.NoError(t, l.ApplyConfig(Config{ JWTClaimMappings: []JWTClaimMapping{ {Claim: "sub", Path: "eventMetaData.requester.sub"}, }, - }) + })) l.PolicyCRUDSuccess(ctx, policyCRUDParams) }) @@ -414,13 +414,13 @@ func TestAuditJWTClaimMappingsCanWriteToEntityMetadata(t *testing.T) { ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { - l.ApplyConfig(Config{ + require.NoError(t, l.ApplyConfig(Config{ JWTClaimMappings: []JWTClaimMapping{ {Claim: "sub", Path: "eventMetaData.entityMetadata.sub"}, {Claim: "realm_access.roles", Path: "eventMetaData.entityMetadata.roles"}, {Claim: "email_verified", Path: "eventMetaData.entityMetadata.emailVerified"}, }, - }) + })) l.PolicyCRUDSuccess(ctx, policyCRUDParams) }) @@ -434,6 +434,43 @@ func TestAuditJWTClaimMappingsCanWriteToEntityMetadata(t *testing.T) { assert.Equal(t, true, entityMetadata["emailVerified"]) } +func TestAuditJWTClaimMappingsLeaveReservedFieldsUntouched(t *testing.T) { + token, rawToken := createTestJWTForAudit(t) + ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) + + logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + require.NoError(t, l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + }, + })) + l.PolicyCRUDSuccess(ctx, policyCRUDParams) + }) + + payload := decodeAuditPayload(t, logEntry.Audit) + assert.Equal(t, TestRequestID.String(), payload["requestID"]) + + eventMetaData, ok := payload["eventMetaData"].(map[string]any) + require.True(t, ok) + requester, ok := eventMetaData["requester"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", requester["sub"]) +} + +func TestAuditApplyConfigRejectsReservedPaths(t *testing.T) { + l, _ := createTestLogger() + + err := l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "requestID"}, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "jwt_claim_mappings[0].path") + require.ErrorContains(t, err, "reserved audit path") +} + func TestDeferredRewrapSuccess(t *testing.T) { logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { l.RewrapSuccess(ctx, rewrapParams) diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go new file mode 100644 index 0000000000..93b0a87c05 --- /dev/null +++ b/service/logger/audit/schema.go @@ -0,0 +1,191 @@ +package audit + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +var ( + errReservedAuditPath = errors.New("reserved audit path") + errUnknownAuditPath = errors.New("unknown audit path") + errAuditContainerPath = errors.New("audit path resolves to a container") + + auditClaimDestinationSchema = mustBuildAuditPathSchema(reflect.TypeOf(EventObject{})) +) + +type auditFieldOptions struct { + name string + reserved bool + extensible bool +} + +type auditPathSchema struct { + children map[string]*auditPathSchema + reserved bool + extensible bool +} + +func mustBuildAuditPathSchema(rootType reflect.Type) *auditPathSchema { + root, err := buildAuditPathSchema(rootType) + if err != nil { + panic(err) + } + return root +} + +func buildAuditPathSchema(rootType reflect.Type) (*auditPathSchema, error) { + rootType = indirectType(rootType) + if rootType.Kind() != reflect.Struct { + return nil, fmt.Errorf("audit schema root must be a struct, got %s", rootType.Kind()) + } + + root := &auditPathSchema{ + children: make(map[string]*auditPathSchema), + } + if err := addAuditSchemaFields(root, rootType); err != nil { + return nil, err + } + return root, nil +} + +func addAuditSchemaFields(parent *auditPathSchema, structType reflect.Type) error { + for i := range structType.NumField() { + field := structType.Field(i) + if !field.IsExported() { + continue + } + + opts, ok := parseAuditFieldOptions(field) + if !ok { + continue + } + if _, exists := parent.children[opts.name]; exists { + return fmt.Errorf("duplicate audit schema path %q on %s", opts.name, structType) + } + + child := &auditPathSchema{ + children: make(map[string]*auditPathSchema), + reserved: opts.reserved, + extensible: opts.extensible, + } + parent.children[opts.name] = child + + fieldType := indirectType(field.Type) + if fieldType.Kind() == reflect.Struct { + if err := addAuditSchemaFields(child, fieldType); err != nil { + return err + } + } + if len(child.children) == 0 { + child.children = nil + } + } + + return nil +} + +func parseAuditFieldOptions(field reflect.StructField) (auditFieldOptions, bool) { + tag := field.Tag.Get("audit") + if tag == "-" { + return auditFieldOptions{}, false + } + + name, reserved, extensible := parseAuditTag(tag) + if name == "" { + name = parseJSONFieldName(field) + } + if name == "" { + return auditFieldOptions{}, false + } + + return auditFieldOptions{ + name: name, + reserved: reserved, + extensible: extensible, + }, true +} + +func parseAuditTag(tag string) (string, bool, bool) { + if tag == "" { + return "", false, false + } + + parts := strings.Split(tag, ",") + name := parts[0] + var ( + reserved bool + extensible bool + ) + for _, option := range parts[1:] { + switch option { + case "reserved": + reserved = true + case "extensible": + extensible = true + } + } + return name, reserved, extensible +} + +func parseJSONFieldName(field reflect.StructField) string { + tag := field.Tag.Get("json") + if tag == "-" { + return "" + } + if tag == "" { + return field.Name + } + + name, _, _ := strings.Cut(tag, ",") + if name == "" { + return field.Name + } + return name +} + +func indirectType(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + return t +} + +func validateClaimDestinationPath(path string) error { + if path == "" { + return fmt.Errorf("%w: empty path", errUnknownAuditPath) + } + + segments := strings.Split(path, ".") + current := auditClaimDestinationSchema + for idx, segment := range segments { + if segment == "" { + return fmt.Errorf("%w: %s", errUnknownAuditPath, path) + } + + child, ok := current.children[segment] + if !ok { + if current.extensible { + return nil + } + return fmt.Errorf("%w: %s", errUnknownAuditPath, path) + } + + isLast := idx == len(segments)-1 + if isLast { + switch { + case child.reserved: + return fmt.Errorf("%w: %s", errReservedAuditPath, path) + case child.extensible || len(child.children) > 0: + return fmt.Errorf("%w: %s", errAuditContainerPath, path) + default: + return nil + } + } + + current = child + } + + return nil +} diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go new file mode 100644 index 0000000000..123cf50053 --- /dev/null +++ b/service/logger/audit/schema_test.go @@ -0,0 +1,40 @@ +package audit + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateClaimDestinationPath(t *testing.T) { + t.Run("allows writable leaf paths", func(t *testing.T) { + require.NoError(t, validateClaimDestinationPath("object.id")) + require.NoError(t, validateClaimDestinationPath("clientInfo.requestIP")) + }) + + t.Run("allows nested paths below extensible maps", func(t *testing.T) { + require.NoError(t, validateClaimDestinationPath("eventMetaData.requester.sub")) + require.NoError(t, validateClaimDestinationPath("original.request.headers.user")) + }) + + t.Run("rejects reserved paths", func(t *testing.T) { + err := validateClaimDestinationPath("requestID") + require.ErrorIs(t, err, errReservedAuditPath) + + err = validateClaimDestinationPath("action.result") + require.ErrorIs(t, err, errReservedAuditPath) + }) + + t.Run("rejects container paths", func(t *testing.T) { + err := validateClaimDestinationPath("eventMetaData") + require.ErrorIs(t, err, errAuditContainerPath) + + err = validateClaimDestinationPath("object") + require.ErrorIs(t, err, errAuditContainerPath) + }) + + t.Run("rejects unknown top level paths", func(t *testing.T) { + err := validateClaimDestinationPath("requester.sub") + require.ErrorIs(t, err, errUnknownAuditPath) + }) +} diff --git a/service/logger/audit/utils.go b/service/logger/audit/utils.go index e87f18ebb4..a65a7adac6 100644 --- a/service/logger/audit/utils.go +++ b/service/logger/audit/utils.go @@ -16,16 +16,16 @@ type auditEventMetadata map[string]any // event type EventObject struct { - Object auditEventObject `json:"object"` - Action eventAction `json:"action"` - Actor auditEventActor `json:"actor"` - EventMetaData auditEventMetadata `json:"eventMetaData"` - ClientInfo eventClientInfo `json:"clientInfo"` + Object auditEventObject `json:"object" audit:"object"` + Action eventAction `json:"action" audit:"action"` + Actor auditEventActor `json:"actor" audit:"actor"` + EventMetaData auditEventMetadata `json:"eventMetaData" audit:"eventMetaData,extensible"` + ClientInfo eventClientInfo `json:"clientInfo" audit:"clientInfo"` - Original map[string]any `json:"original,omitempty"` - Updated map[string]any `json:"updated,omitempty"` - RequestID uuid.UUID `json:"requestId"` - Timestamp string `json:"timestamp"` + Original map[string]any `json:"original,omitempty" audit:"original,extensible"` + Updated map[string]any `json:"updated,omitempty" audit:"updated,extensible"` + RequestID uuid.UUID `json:"requestId" audit:"requestID,reserved"` + Timestamp string `json:"timestamp" audit:"timestamp,reserved"` } func (e EventObject) LogValue() slog.Value { @@ -76,10 +76,10 @@ func (e EventObject) logMap() map[string]any { // event.object type auditEventObject struct { - Type ObjectType `json:"type"` - ID string `json:"id"` - Name string `json:"name,omitempty"` - Attributes eventObjectAttributes `json:"attributes,omitempty"` + Type ObjectType `json:"type" audit:"type,reserved"` + ID string `json:"id" audit:"id"` + Name string `json:"name,omitempty" audit:"name"` + Attributes eventObjectAttributes `json:"attributes,omitempty" audit:"attributes"` } func (e auditEventObject) LogValue() slog.Value { @@ -106,8 +106,8 @@ func (e eventObjectAttributes) LogValue() slog.Value { // event.action type eventAction struct { - Type ActionType `json:"type"` - Result ActionResult `json:"result"` + Type ActionType `json:"type" audit:"type,reserved"` + Result ActionResult `json:"result" audit:"result,reserved"` } func (e eventAction) LogValue() slog.Value { @@ -118,8 +118,8 @@ func (e eventAction) LogValue() slog.Value { // event.actor type auditEventActor struct { - ID string `json:"id"` - Attributes []any `json:"attributes"` + ID string `json:"id" audit:"id,reserved"` + Attributes []any `json:"attributes" audit:"attributes"` } func (e auditEventActor) LogValue() slog.Value { @@ -130,9 +130,9 @@ func (e auditEventActor) LogValue() slog.Value { // event.clientInfo type eventClientInfo struct { - UserAgent string `json:"userAgent"` - Platform string `json:"platform"` - RequestIP string `json:"requestIp"` + UserAgent string `json:"userAgent" audit:"userAgent"` + Platform string `json:"platform" audit:"platform,reserved"` + RequestIP string `json:"requestIp" audit:"requestIP"` } func (e eventClientInfo) LogValue() slog.Value { diff --git a/service/pkg/config/config.go b/service/pkg/config/config.go index acbc1f3b28..355cca081d 100644 --- a/service/pkg/config/config.go +++ b/service/pkg/config/config.go @@ -299,6 +299,9 @@ func (c *Config) Reload(ctx context.Context) error { if err := validator.New().Struct(c); err != nil { return errors.Join(err, ErrUnmarshallingConfig) } + if err := c.Audit.Validate(); err != nil { + return errors.Join(err, ErrUnmarshallingConfig) + } if skew := c.Security.ClockSkew(); skew > DefaultUnsafeClockSkew { slog.WarnContext(ctx, diff --git a/service/pkg/config/config_test.go b/service/pkg/config/config_test.go index 6c7778d030..506067e09c 100644 --- a/service/pkg/config/config_test.go +++ b/service/pkg/config/config_test.go @@ -357,6 +357,7 @@ func TestLoad_Precedence(t *testing.T) { setupLoaders func(t *testing.T, configFile string) []Loader envVars map[string]string err error + errContains string fileContent string asserts func(t *testing.T, cfg *Config) }{ @@ -434,6 +435,24 @@ special_key: assert.Equal(t, []string{"all"}, cfg.Mode) }, }, + { + name: "invalid audit mapping path fails validation", + setupLoaders: func(t *testing.T, configFile string) []Loader { + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + return []Loader{file, defaults} + }, + err: ErrUnmarshallingConfig, + errContains: "reserved audit path", + fileContent: ` +audit: + jwt_claim_mappings: + - claim: sub + path: requestID +`, + }, { name: "env overrides file and defaults except client_id", setupLoaders: func(t *testing.T, configFile string) []Loader { @@ -670,6 +689,11 @@ logger: // Assertions if tc.err != nil { require.Error(t, err) + require.ErrorIs(t, err, tc.err) + if tc.errContains != "" { + require.ErrorContains(t, err, tc.errContains) + } + require.Nil(t, cfg) } else { require.NoError(t, err) require.NotNil(t, cfg) diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index 7a2d3df3a2..2ae39f1665 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -164,16 +164,9 @@ func startServices(ctx context.Context, params startServicesParams) (func(), err // If ns has log_level in config, create new logger with that level if err == nil { - if extractedLogLevel != cfg.Logger.Level { - slog.Debug("configuring logger") - newLoggerConfig := cfg.Logger - newLoggerConfig.Level = extractedLogLevel - newSvcLogger, err := logging.NewLogger(newLoggerConfig) - // only assign if logger successfully created - if err == nil { - newSvcLogger.Audit.ApplyConfig(cfg.Audit) - svcLogger = newSvcLogger.With("namespace", ns) - } + svcLogger, err = buildNamespaceLogger(svcLogger, cfg, ns, extractedLogLevel) + if err != nil { + return func() {}, err } } @@ -288,6 +281,26 @@ func extractServiceLoggerConfig(cfg config.ServiceConfig) (string, error) { return "", fmt.Errorf("could not decode service log level: %w", err) } +func buildNamespaceLogger(baseLogger *logging.Logger, cfg *config.Config, ns, level string) (*logging.Logger, error) { + if level == cfg.Logger.Level { + return baseLogger, nil + } + + slog.Debug("configuring logger") + newLoggerConfig := cfg.Logger + newLoggerConfig.Level = level + + if namespaceLogger, loggerErr := logging.NewLogger(newLoggerConfig); loggerErr == nil { + if err := namespaceLogger.Audit.ApplyConfig(cfg.Audit); err != nil { + return nil, fmt.Errorf("could not apply audit config for namespace %s: %w", ns, err) + } + return namespaceLogger.With("namespace", ns), nil + } + + // Keep the existing logger if the override could not be created. + return baseLogger, nil +} + // newServiceDBClient creates a new database client for the specified namespace. // It initializes the client with the provided context, logger configuration, database configuration, // namespace, and migrations. It returns the created client and any error encountered during creation. diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index a9a39e54dd..ea023d5e31 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -112,7 +112,9 @@ func Start(f ...StartOptions) error { if err != nil { return fmt.Errorf("could not start logger: %w", err) } - logger.Audit.ApplyConfig(cfg.Audit) + if err := logger.Audit.ApplyConfig(cfg.Audit); err != nil { + return fmt.Errorf("could not apply audit config: %w", err) + } // Set default for places we can't pass the logger slog.SetDefault(logger.Logger) From d9595b6820dc717ce47e03f74312361ef49d1376 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 13:25:42 -0700 Subject: [PATCH 04/20] fix(audit): allow top-level claim destinations Signed-off-by: jakedoublev --- service/logger/audit/schema.go | 3 ++- service/logger/audit/schema_test.go | 14 ++++++++++++-- service/pkg/config/config_test.go | 21 +++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go index 93b0a87c05..ade402d32c 100644 --- a/service/logger/audit/schema.go +++ b/service/logger/audit/schema.go @@ -42,7 +42,8 @@ func buildAuditPathSchema(rootType reflect.Type) (*auditPathSchema, error) { } root := &auditPathSchema{ - children: make(map[string]*auditPathSchema), + children: make(map[string]*auditPathSchema), + extensible: true, } if err := addAuditSchemaFields(root, rootType); err != nil { return nil, err diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go index 123cf50053..9e58bbbc16 100644 --- a/service/logger/audit/schema_test.go +++ b/service/logger/audit/schema_test.go @@ -17,6 +17,11 @@ func TestValidateClaimDestinationPath(t *testing.T) { require.NoError(t, validateClaimDestinationPath("original.request.headers.user")) }) + t.Run("allows top level additions", func(t *testing.T) { + require.NoError(t, validateClaimDestinationPath("banana")) + require.NoError(t, validateClaimDestinationPath("banana.requester.sub")) + }) + t.Run("rejects reserved paths", func(t *testing.T) { err := validateClaimDestinationPath("requestID") require.ErrorIs(t, err, errReservedAuditPath) @@ -33,8 +38,13 @@ func TestValidateClaimDestinationPath(t *testing.T) { require.ErrorIs(t, err, errAuditContainerPath) }) - t.Run("rejects unknown top level paths", func(t *testing.T) { - err := validateClaimDestinationPath("requester.sub") + t.Run("rejects unknown nested paths below closed containers", func(t *testing.T) { + err := validateClaimDestinationPath("object.extra.foo") + require.ErrorIs(t, err, errUnknownAuditPath) + }) + + t.Run("rejects leading dot paths", func(t *testing.T) { + err := validateClaimDestinationPath(".banana") require.ErrorIs(t, err, errUnknownAuditPath) }) } diff --git a/service/pkg/config/config_test.go b/service/pkg/config/config_test.go index 506067e09c..129918217d 100644 --- a/service/pkg/config/config_test.go +++ b/service/pkg/config/config_test.go @@ -406,6 +406,27 @@ audit: assert.Equal(t, []string{"all"}, cfg.Mode) }, }, + { + name: "top level audit mapping paths pass validation", + setupLoaders: func(t *testing.T, configFile string) []Loader { + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + return []Loader{file, defaults} + }, + fileContent: ` +audit: + jwt_claim_mappings: + - claim: email_verified + path: banana.requester.emailVerified +`, + asserts: func(t *testing.T, cfg *Config) { + assert.Equal(t, []auditcfg.JWTClaimMapping{ + {Claim: "email_verified", Path: "banana.requester.emailVerified"}, + }, cfg.Audit.JWTClaimMappings) + }, + }, { name: "file with extras and defaults", setupLoaders: func(t *testing.T, configFile string) []Loader { From e2cb8aa803a80fd33d2e2f70167440ce836dde0d Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 13:33:18 -0700 Subject: [PATCH 05/20] chore(audit): simplify dotnotation import alias Signed-off-by: jakedoublev --- service/logger/audit/enrichment.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/service/logger/audit/enrichment.go b/service/logger/audit/enrichment.go index 260be1a7bd..90eabd6b73 100644 --- a/service/logger/audit/enrichment.go +++ b/service/logger/audit/enrichment.go @@ -5,7 +5,7 @@ import ( "log/slog" "reflect" - internaldotnotation "github.com/opentdf/platform/service/internal/dotnotation" + dotnotation "github.com/opentdf/platform/service/internal/dotnotation" ctxAuth "github.com/opentdf/platform/service/pkg/auth" ) @@ -40,12 +40,12 @@ func (a *Logger) applyMappedJWTClaims(ctx context.Context, entry map[string]any, continue } - value := internaldotnotation.Get(claimsMap, mapping.Claim) + value := dotnotation.Get(claimsMap, mapping.Claim) if value == nil { continue } - if err := internaldotnotation.Set(entry, mapping.Path, normalizeAuditValue(value)); err != nil { + if err := dotnotation.Set(entry, mapping.Path, normalizeAuditValue(value)); err != nil { a.logger.ErrorContext(ctx, "failed to apply JWT claim mapping to audit log", slog.String("claim", mapping.Claim), From ffcc5f2de92170db341fd2d3d9a78ba57347008e Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 13:38:32 -0700 Subject: [PATCH 06/20] test(audit): cover more claim destination paths Signed-off-by: jakedoublev --- service/logger/audit/logger_test.go | 44 +++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index c818a54a44..228e3063f7 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -434,6 +434,50 @@ func TestAuditJWTClaimMappingsCanWriteToEntityMetadata(t *testing.T) { assert.Equal(t, true, entityMetadata["emailVerified"]) } +func TestAuditJWTClaimMappingsCoverNamedAndUnnamedPaths(t *testing.T) { + token, rawToken := createTestJWTForAudit(t) + ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) + + logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + require.NoError(t, l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "object.name"}, + {Claim: "realm_access.roles", Path: "actor.attributes"}, + {Claim: "sub", Path: "original.request.jwt.sub"}, + {Claim: "sub", Path: "banana"}, + {Claim: "email_verified", Path: "kiwi.requester.emailVerified"}, + }, + })) + l.PolicyCRUDSuccess(ctx, policyCRUDParams) + }) + + payload := decodeAuditPayload(t, logEntry.Audit) + + object, ok := payload["object"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", object["name"]) + + actor, ok := payload["actor"].(map[string]any) + require.True(t, ok) + assert.Equal(t, []any{"admin", "user"}, actor["attributes"]) + + original, ok := payload["original"].(map[string]any) + require.True(t, ok) + request, ok := original["request"].(map[string]any) + require.True(t, ok) + jwt, ok := request["jwt"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "jwt-user", jwt["sub"]) + + assert.Equal(t, "jwt-user", payload["banana"]) + + kiwi, ok := payload["kiwi"].(map[string]any) + require.True(t, ok) + requester, ok := kiwi["requester"].(map[string]any) + require.True(t, ok) + assert.Equal(t, true, requester["emailVerified"]) +} + func TestAuditJWTClaimMappingsLeaveReservedFieldsUntouched(t *testing.T) { token, rawToken := createTestJWTForAudit(t) ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) From 35103dc037d812ee9d502f7f01978b31bd209403 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 14:13:20 -0700 Subject: [PATCH 07/20] fix(auth): rehydrate IPC auth context Signed-off-by: jakedoublev --- service/internal/auth/authn.go | 5 ++ .../authn_ipc_metadata_interceptor_test.go | 54 ++++++++++++-- service/logger/audit/logger_test.go | 4 +- service/pkg/auth/context_auth.go | 56 ++++++++------- service/pkg/auth/context_auth_test.go | 70 +++++++++++++++++-- 5 files changed, 149 insertions(+), 40 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index b5c68ba837..99768fe0b1 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -442,6 +442,7 @@ func IPCMetadataClientInterceptor(log *logger.Logger) connect.UnaryInterceptorFu // 1. verifies the token in the metadata // 2. reauthorizes the token if the route is in the list // 3. translates known IPC Connect request headers back to context metadata for downstream consumers +// 4. rehydrates auth context from incoming metadata for non-reauth IPC routes func (a Authentication) IPCUnaryServerInterceptor() connect.UnaryInterceptorFunc { interceptor := func(next connect.UnaryFunc) connect.UnaryFunc { return connect.UnaryFunc(func( @@ -467,6 +468,10 @@ func (a Authentication) IPCUnaryServerInterceptor() connect.UnaryInterceptorFunc if err != nil { return nil, err } + nextCtx, err = ctxAuth.RehydrateAccessTokenFromIncomingMetadata(nextCtx, a.logger) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, errors.New("failed to rehydrate IPC authentication context")) + } return next(nextCtx, req) }) } diff --git a/service/internal/auth/authn_ipc_metadata_interceptor_test.go b/service/internal/auth/authn_ipc_metadata_interceptor_test.go index 16d15c3e26..0ac1603835 100644 --- a/service/internal/auth/authn_ipc_metadata_interceptor_test.go +++ b/service/internal/auth/authn_ipc_metadata_interceptor_test.go @@ -2,6 +2,7 @@ package auth import ( "context" + "errors" "testing" "connectrpc.com/connect" @@ -184,6 +185,10 @@ func (m *mockAnyRequest) Any() any { func TestIPCUnaryServerInterceptor(t *testing.T) { testLogger := logger.CreateTestLogger() + validJWT, err := jwt.NewBuilder().Subject("ipc-user").Build() + require.NoError(t, err) + validRawToken, err := jwt.Sign(validJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) // Create a minimal authentication instance auth := &Authentication{ @@ -201,7 +206,7 @@ func TestIPCUnaryServerInterceptor(t *testing.T) { setupRequest: func() connect.AnyRequest { req := connect.NewRequest(&kas.PublicKeyRequest{}) req.Header().Set(canonicalIPCHeaderClientID, "test-client-from-header") - req.Header().Set(canonicalIPCHeaderAccessToken, "test-token-from-header") + req.Header().Set(canonicalIPCHeaderAccessToken, string(validRawToken)) return &mockAnyRequest{ Request: req, isClient: false, @@ -225,7 +230,7 @@ func TestIPCUnaryServerInterceptor(t *testing.T) { setupRequest: func() connect.AnyRequest { req := connect.NewRequest(&kas.PublicKeyRequest{}) req.Header().Set(canonicalIPCHeaderClientID, "merged-client-id") - req.Header().Set(canonicalIPCHeaderAccessToken, "merged-token") + req.Header().Set(canonicalIPCHeaderAccessToken, string(validRawToken)) return &mockAnyRequest{ Request: req, isClient: false, @@ -251,8 +256,13 @@ func TestIPCUnaryServerInterceptor(t *testing.T) { for _, key := range tt.expectedIncomingMDKeys { assert.NotEmpty(t, md.Get(key), "metadata key %s should exist", key) } + retrievedJWT := ctxAuth.GetAccessTokenFromContext(postInterceptorCtx, testLogger) + require.NotNil(t, retrievedJWT) + assert.Equal(t, "ipc-user", retrievedJWT.Subject()) + assert.Equal(t, string(validRawToken), ctxAuth.GetRawAccessTokenFromContext(postInterceptorCtx, testLogger)) } else { assert.Zero(t, md.Len()) + assert.Nil(t, ctxAuth.GetAccessTokenFromContext(postInterceptorCtx, testLogger)) } return connect.NewResponse(&kas.PublicKeyResponse{}), nil } @@ -267,19 +277,22 @@ func TestIPCUnaryServerInterceptor(t *testing.T) { func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { testLogger := logger.CreateTestLogger() + mockJWT, err := jwt.NewBuilder().Subject("integration-user").Build() + require.NoError(t, err) + rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) auth := &Authentication{ logger: testLogger, ipcReauthRoutes: []string{}, } - t.Run("clientID and access token from headers available in context metadata", func(t *testing.T) { + t.Run("clientID and access token from headers available in context metadata and auth context", func(t *testing.T) { clientID := "integration-client-id" - accessToken := "integration-access-token" req := connect.NewRequest(&kas.PublicKeyRequest{}) req.Header().Set(canonicalIPCHeaderClientID, clientID) - req.Header().Set(canonicalIPCHeaderAccessToken, accessToken) + req.Header().Set(canonicalIPCHeaderAccessToken, string(rawToken)) wrappedReq := &mockAnyRequest{ Request: req, @@ -290,7 +303,7 @@ func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { ctx := t.Context() - var receivedClientID, receivedAccessToken string + var receivedClientID, receivedAccessToken, receivedSubject string mockNext := func(ctx context.Context, _ connect.AnyRequest) (connect.AnyResponse, error) { md, ok := metadata.FromIncomingContext(ctx) require.True(t, ok) @@ -302,6 +315,9 @@ func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { if len(accessTokens) > 0 { receivedAccessToken = accessTokens[0] } + retrievedJWT := ctxAuth.GetAccessTokenFromContext(ctx, testLogger) + require.NotNil(t, retrievedJWT) + receivedSubject = retrievedJWT.Subject() return connect.NewResponse(&kas.PublicKeyResponse{}), nil } @@ -310,6 +326,30 @@ func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { require.NoError(t, err) assert.Equal(t, clientID, receivedClientID) - assert.Equal(t, accessToken, receivedAccessToken) + assert.Equal(t, string(rawToken), receivedAccessToken) + assert.Equal(t, "integration-user", receivedSubject) + }) + + t.Run("invalid access token header fails rehydration", func(t *testing.T) { + req := connect.NewRequest(&kas.PublicKeyRequest{}) + req.Header().Set(canonicalIPCHeaderAccessToken, "not-a-jwt") + + wrappedReq := &mockAnyRequest{ + Request: req, + isClient: false, + } + + interceptor := auth.IPCUnaryServerInterceptor() + interceptorFunc := interceptor(func(context.Context, connect.AnyRequest) (connect.AnyResponse, error) { + t.Fatal("next handler should not be called on invalid token") + return nil, errors.New("unreachable") + }) + + _, err := interceptorFunc(t.Context(), wrappedReq) + require.Error(t, err) + + var connectErr *connect.Error + require.ErrorAs(t, err, &connectErr) + assert.Equal(t, connect.CodeInternal, connectErr.Code()) }) } diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index 228e3063f7..7ef83b1638 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -388,9 +388,11 @@ func TestAuditJWTClaimMappingsApplyToPolicyAudit(t *testing.T) { assert.Equal(t, true, requester["emailVerified"]) } -func TestAuditJWTClaimMappingsUseMetadataFallback(t *testing.T) { +func TestAuditJWTClaimMappingsUseRehydratedAuthContext(t *testing.T) { _, rawToken := createTestJWTForAudit(t) ctx := metadata.NewIncomingContext(createTestContext(t), metadata.Pairs(ctxAuth.AccessTokenKey, rawToken)) + ctx, err := ctxAuth.RehydrateAccessTokenFromIncomingMetadata(ctx, nil) + require.NoError(t, err) logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { require.NoError(t, l.ApplyConfig(Config{ diff --git a/service/pkg/auth/context_auth.go b/service/pkg/auth/context_auth.go index 904ba63e8c..d1ac292bc6 100644 --- a/service/pkg/auth/context_auth.go +++ b/service/pkg/auth/context_auth.go @@ -3,6 +3,7 @@ package auth import ( "context" "errors" + "fmt" "strings" "github.com/lestrrat-go/jwx/v2/jwk" @@ -30,7 +31,9 @@ type authContext struct { rawToken string } -type contextErrorLogger interface { +// optionalErrorLogger keeps pkg/auth decoupled from the concrete logger package +// and helps avoid an import cycle. +type optionalErrorLogger interface { ErrorContext(context.Context, string, ...any) } @@ -42,7 +45,7 @@ func ContextWithAuthNInfo(ctx context.Context, key jwk.Key, accessToken jwt.Toke }) } -func getContextDetails(ctx context.Context, l contextErrorLogger) *authContext { +func getContextDetails(ctx context.Context, l optionalErrorLogger) *authContext { key := ctx.Value(authnContextKey) if key == nil { return nil @@ -57,58 +60,59 @@ func getContextDetails(ctx context.Context, l contextErrorLogger) *authContext { return nil } -func GetJWKFromContext(ctx context.Context, l contextErrorLogger) jwk.Key { +func GetJWKFromContext(ctx context.Context, l optionalErrorLogger) jwk.Key { if c := getContextDetails(ctx, l); c != nil { return c.key } return nil } -func GetAccessTokenFromContext(ctx context.Context, l contextErrorLogger) jwt.Token { +func GetAccessTokenFromContext(ctx context.Context, l optionalErrorLogger) jwt.Token { if c := getContextDetails(ctx, l); c != nil { if c.accessToken != nil { return c.accessToken } } - - rawToken := GetRawAccessTokenFromContext(ctx, l) - if rawToken == "" { - return nil - } - - parsed, err := jwt.Parse([]byte(rawToken), jwt.WithVerify(false), jwt.WithValidate(false)) - if err != nil { - if l != nil { - l.ErrorContext(ctx, "failed to parse access token from context", "error", err) - } - return nil - } - - return parsed + return nil } -func GetRawAccessTokenFromContext(ctx context.Context, l contextErrorLogger) string { +func GetRawAccessTokenFromContext(ctx context.Context, l optionalErrorLogger) string { if c := getContextDetails(ctx, l); c != nil { if c.rawToken != "" { return c.rawToken } } + return "" +} - if rawToken := getRawAccessTokenFromMetadata(ctx, true); rawToken != "" { - return rawToken +// RehydrateAccessTokenFromIncomingMetadata reconstructs auth context from incoming metadata +// so downstream code can use the normal accessors without transport-specific fallbacks. +func RehydrateAccessTokenFromIncomingMetadata(ctx context.Context, l optionalErrorLogger) (context.Context, error) { + if c := getContextDetails(ctx, l); c != nil && c.accessToken != nil && c.rawToken != "" { + return ctx, nil } - if rawToken := getRawAccessTokenFromMetadata(ctx, false); rawToken != "" { - return rawToken + + rawToken := getRawAccessTokenFromMetadata(ctx, true) + if rawToken == "" { + return ctx, nil } - return "" + parsed, err := jwt.Parse([]byte(rawToken), jwt.WithVerify(false), jwt.WithValidate(false)) + if err != nil { + if l != nil { + l.ErrorContext(ctx, "failed to rehydrate access token from incoming metadata", "error", err) + } + return ctx, fmt.Errorf("rehydrate access token from incoming metadata: %w", err) + } + + return ContextWithAuthNInfo(ctx, nil, parsed, rawToken), nil } // EnrichIncomingContextMetadataWithAuthn adds the access token and client ID to incoming context metadata // // Adding the authn info to gRPC metadata propagates it across services rather than strictly // in-process within Go alone -func EnrichIncomingContextMetadataWithAuthn(ctx context.Context, l contextErrorLogger, clientID string) context.Context { +func EnrichIncomingContextMetadataWithAuthn(ctx context.Context, l optionalErrorLogger, clientID string) context.Context { rawToken := GetRawAccessTokenFromContext(ctx, l) md, ok := metadata.FromIncomingContext(ctx) diff --git a/service/pkg/auth/context_auth_test.go b/service/pkg/auth/context_auth_test.go index e09b193b3b..f88cc4b2fd 100644 --- a/service/pkg/auth/context_auth_test.go +++ b/service/pkg/auth/context_auth_test.go @@ -63,21 +63,21 @@ func TestGetRawAccessTokenFromContext(t *testing.T) { assert.Equal(t, rawToken, retrievedRawToken, "Retrieved raw token should match the mock raw token") } -func TestGetRawAccessTokenFromContextFallsBackToMetadata(t *testing.T) { +func TestGetRawAccessTokenFromContextDoesNotFallbackToMetadata(t *testing.T) { t.Run("incoming access token metadata", func(t *testing.T) { ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, "incoming-token")) retrievedRawToken := GetRawAccessTokenFromContext(ctx, nil) - assert.Equal(t, "incoming-token", retrievedRawToken) + assert.Empty(t, retrievedRawToken) }) t.Run("outgoing authorization metadata", func(t *testing.T) { ctx := metadata.NewOutgoingContext(t.Context(), metadata.Pairs("Authorization", "Bearer outgoing-token")) retrievedRawToken := GetRawAccessTokenFromContext(ctx, nil) - assert.Equal(t, "outgoing-token", retrievedRawToken) + assert.Empty(t, retrievedRawToken) }) } -func TestGetAccessTokenFromContextFallsBackToMetadata(t *testing.T) { +func TestGetAccessTokenFromContextDoesNotFallbackToMetadata(t *testing.T) { mockJWT, err := jwt.NewBuilder(). Subject("metadata-user"). Claim("roles", []string{"admin"}). @@ -89,8 +89,66 @@ func TestGetAccessTokenFromContextFallsBackToMetadata(t *testing.T) { ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, string(rawToken))) retrievedJWT := GetAccessTokenFromContext(ctx, nil) - require.NotNil(t, retrievedJWT) - assert.Equal(t, "metadata-user", retrievedJWT.Subject()) + assert.Nil(t, retrievedJWT) +} + +func TestRehydrateAccessTokenFromIncomingMetadata(t *testing.T) { + t.Run("rehydrates from access token metadata", func(t *testing.T) { + mockJWT, err := jwt.NewBuilder(). + Subject("metadata-user"). + Claim("roles", []string{"admin"}). + Build() + require.NoError(t, err) + + rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) + + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, string(rawToken))) + rehydratedCtx, err := RehydrateAccessTokenFromIncomingMetadata(ctx, nil) + require.NoError(t, err) + + retrievedJWT := GetAccessTokenFromContext(rehydratedCtx, nil) + require.NotNil(t, retrievedJWT) + assert.Equal(t, "metadata-user", retrievedJWT.Subject()) + assert.Equal(t, string(rawToken), GetRawAccessTokenFromContext(rehydratedCtx, nil)) + }) + + t.Run("uses authorization metadata", func(t *testing.T) { + mockJWT, err := jwt.NewBuilder().Subject("authorization-user").Build() + require.NoError(t, err) + + rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) + + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("Authorization", "Bearer "+string(rawToken))) + rehydratedCtx, err := RehydrateAccessTokenFromIncomingMetadata(ctx, nil) + require.NoError(t, err) + + retrievedJWT := GetAccessTokenFromContext(rehydratedCtx, nil) + require.NotNil(t, retrievedJWT) + assert.Equal(t, "authorization-user", retrievedJWT.Subject()) + }) + + t.Run("returns unchanged context when auth context already exists", func(t *testing.T) { + existingJWT, err := jwt.NewBuilder().Subject("existing-user").Build() + require.NoError(t, err) + + ctx := ContextWithAuthNInfo(t.Context(), nil, existingJWT, "existing-raw-token") + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(AccessTokenKey, "different-token")) + rehydratedCtx, err := RehydrateAccessTokenFromIncomingMetadata(ctx, nil) + require.NoError(t, err) + + retrievedJWT := GetAccessTokenFromContext(rehydratedCtx, nil) + require.NotNil(t, retrievedJWT) + assert.Equal(t, "existing-user", retrievedJWT.Subject()) + assert.Equal(t, "existing-raw-token", GetRawAccessTokenFromContext(rehydratedCtx, nil)) + }) + + t.Run("returns error on invalid token metadata", func(t *testing.T) { + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, "not-a-jwt")) + _, err := RehydrateAccessTokenFromIncomingMetadata(ctx, nil) + require.Error(t, err) + }) } func TestGetContextDetailsInvalidType(t *testing.T) { From 18d472f18e98eb68560ef24621c69bbe8745b1e1 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 14:38:27 -0700 Subject: [PATCH 08/20] docs(auth): align IPC auth comments Signed-off-by: jakedoublev --- service/internal/auth/authn.go | 7 +++---- service/pkg/server/options.go | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index 99768fe0b1..2042ef1602 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -439,10 +439,9 @@ func IPCMetadataClientInterceptor(log *logger.Logger) connect.UnaryInterceptorFu } // IPCUnaryServerInterceptor is a grpc interceptor that: -// 1. verifies the token in the metadata -// 2. reauthorizes the token if the route is in the list -// 3. translates known IPC Connect request headers back to context metadata for downstream consumers -// 4. rehydrates auth context from incoming metadata for non-reauth IPC routes +// 1. translates known IPC Connect request headers back to incoming metadata +// 2. reauthorizes routes that are configured for IPC reauth +// 3. rehydrates auth context from propagated incoming metadata func (a Authentication) IPCUnaryServerInterceptor() connect.UnaryInterceptorFunc { interceptor := func(next connect.UnaryFunc) connect.UnaryFunc { return connect.UnaryFunc(func( diff --git a/service/pkg/server/options.go b/service/pkg/server/options.go index db3952ceef..89c25725e7 100644 --- a/service/pkg/server/options.go +++ b/service/pkg/server/options.go @@ -83,8 +83,8 @@ func WithPublicRoutes(routes []string) StartOptions { } } -// WithIPCReauthRoutes option sets the IPC reauthorization routes for the server. -// It enables the server to reauthorize IPC routes and embed the token on the context. +// WithIPCReauthRoutes sets the IPC routes that should be fully reauthorized +// instead of only rehydrating auth context from propagated metadata. func WithIPCReauthRoutes(routes []string) StartOptions { return func(c StartConfig) StartConfig { c.IPCReauthRoutes = routes From d42fb00fb3734efc383bfbb654d35372ce13487b Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 14:54:47 -0700 Subject: [PATCH 09/20] test(audit): simplify logger test helper Signed-off-by: jakedoublev --- service/logger/audit/logger_test.go | 57 ++++++++++++++++------------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index 7ef83b1638..6031dcde58 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -94,11 +94,11 @@ func extractLogEntry(t *testing.T, logBuffer *bytes.Buffer) (logEntryStructure, return entry, entryTime } -func doWithLogger(t *testing.T, testFunc func(ctx context.Context, l *Logger)) (ls logEntryStructure, lt time.Time) { //nolint:nonamedreturns // Required to rewrite on panics, right? - return doWithLoggerContext(createTestContext(t), t, testFunc) -} - -func doWithLoggerContext(ctx context.Context, t *testing.T, testFunc func(ctx context.Context, l *Logger)) (ls logEntryStructure, lt time.Time) { //nolint:nonamedreturns // Required to rewrite on panics, right? +func doWithLogger(t *testing.T, contextSetup func(context.Context) context.Context, testFunc func(ctx context.Context, l *Logger)) (ls logEntryStructure, lt time.Time) { //nolint:nonamedreturns // Named returns let the deferred recover path populate the extracted audit log. + ctx := createTestContext(t) + if contextSetup != nil { + ctx = contextSetup(ctx) + } l, buf := createTestLogger() tx, ok := ctx.Value(contextKey{}).(*auditTransaction) require.True(t, ok, "audit transaction missing from context") @@ -146,7 +146,7 @@ func decodeAuditPayload(t *testing.T, payload json.RawMessage) map[string]any { } func TestAuditRewrapSuccess(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.RewrapSuccess(ctx, rewrapParams) }) @@ -205,7 +205,7 @@ func TestAuditRewrapSuccess(t *testing.T) { } func TestAuditRewrapFailure(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.RewrapFailure(ctx, rewrapParams) }) @@ -264,7 +264,7 @@ func TestAuditRewrapFailure(t *testing.T) { } func TestPolicyCRUDSuccess(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.PolicyCRUDSuccess(ctx, policyCRUDParams) }) @@ -314,7 +314,7 @@ func TestPolicyCRUDSuccess(t *testing.T) { } func TestPolicyCrudFailure(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.PolicyCRUDFailure(ctx, policyCRUDParams) }) @@ -365,9 +365,10 @@ func TestPolicyCrudFailure(t *testing.T) { func TestAuditJWTClaimMappingsApplyToPolicyAudit(t *testing.T) { token, rawToken := createTestJWTForAudit(t) - ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) - logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { + return ctxAuth.ContextWithAuthNInfo(ctx, nil, token, rawToken) + }, func(ctx context.Context, l *Logger) { require.NoError(t, l.ApplyConfig(Config{ JWTClaimMappings: []JWTClaimMapping{ {Claim: "sub", Path: "eventMetaData.requester.sub"}, @@ -390,11 +391,14 @@ func TestAuditJWTClaimMappingsApplyToPolicyAudit(t *testing.T) { func TestAuditJWTClaimMappingsUseRehydratedAuthContext(t *testing.T) { _, rawToken := createTestJWTForAudit(t) - ctx := metadata.NewIncomingContext(createTestContext(t), metadata.Pairs(ctxAuth.AccessTokenKey, rawToken)) - ctx, err := ctxAuth.RehydrateAccessTokenFromIncomingMetadata(ctx, nil) - require.NoError(t, err) - logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(ctxAuth.AccessTokenKey, rawToken)) + var err error + ctx, err = ctxAuth.RehydrateAccessTokenFromIncomingMetadata(ctx, nil) + require.NoError(t, err) + return ctx + }, func(ctx context.Context, l *Logger) { require.NoError(t, l.ApplyConfig(Config{ JWTClaimMappings: []JWTClaimMapping{ {Claim: "sub", Path: "eventMetaData.requester.sub"}, @@ -413,9 +417,10 @@ func TestAuditJWTClaimMappingsUseRehydratedAuthContext(t *testing.T) { func TestAuditJWTClaimMappingsCanWriteToEntityMetadata(t *testing.T) { token, rawToken := createTestJWTForAudit(t) - ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) - logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { + return ctxAuth.ContextWithAuthNInfo(ctx, nil, token, rawToken) + }, func(ctx context.Context, l *Logger) { require.NoError(t, l.ApplyConfig(Config{ JWTClaimMappings: []JWTClaimMapping{ {Claim: "sub", Path: "eventMetaData.entityMetadata.sub"}, @@ -438,9 +443,10 @@ func TestAuditJWTClaimMappingsCanWriteToEntityMetadata(t *testing.T) { func TestAuditJWTClaimMappingsCoverNamedAndUnnamedPaths(t *testing.T) { token, rawToken := createTestJWTForAudit(t) - ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) - logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { + return ctxAuth.ContextWithAuthNInfo(ctx, nil, token, rawToken) + }, func(ctx context.Context, l *Logger) { require.NoError(t, l.ApplyConfig(Config{ JWTClaimMappings: []JWTClaimMapping{ {Claim: "sub", Path: "object.name"}, @@ -482,9 +488,10 @@ func TestAuditJWTClaimMappingsCoverNamedAndUnnamedPaths(t *testing.T) { func TestAuditJWTClaimMappingsLeaveReservedFieldsUntouched(t *testing.T) { token, rawToken := createTestJWTForAudit(t) - ctx := ctxAuth.ContextWithAuthNInfo(createTestContext(t), nil, token, rawToken) - logEntry, _ := doWithLoggerContext(ctx, t, func(ctx context.Context, l *Logger) { + logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { + return ctxAuth.ContextWithAuthNInfo(ctx, nil, token, rawToken) + }, func(ctx context.Context, l *Logger) { require.NoError(t, l.ApplyConfig(Config{ JWTClaimMappings: []JWTClaimMapping{ {Claim: "sub", Path: "eventMetaData.requester.sub"}, @@ -518,7 +525,7 @@ func TestAuditApplyConfigRejectsReservedPaths(t *testing.T) { } func TestDeferredRewrapSuccess(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.RewrapSuccess(ctx, rewrapParams) }) @@ -577,7 +584,7 @@ func TestDeferredRewrapSuccess(t *testing.T) { } func TestDeferredRewrapCancelled(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.RewrapSuccess(ctx, rewrapParams) panic(errors.New("operation failed")) }) @@ -639,7 +646,7 @@ func TestDeferredRewrapCancelled(t *testing.T) { } func TestDeferredPolicyCRUDSuccess(t *testing.T) { - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.PolicyCRUDSuccess(ctx, policyCRUDParams) }) @@ -702,7 +709,7 @@ func TestGetDecision(t *testing.T) { FQNs: []string{"test-fqn"}, } - logEntry, logEntryTime := doWithLogger(t, func(ctx context.Context, l *Logger) { + logEntry, logEntryTime := doWithLogger(t, nil, func(ctx context.Context, l *Logger) { l.GetDecision(ctx, params) }) expectedAuditLog := fmt.Sprintf( From b3687d8e26bda0741d7d136b27df27653ba1063d Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 14:58:42 -0700 Subject: [PATCH 10/20] refactor(audit): export validation errors Signed-off-by: jakedoublev --- service/logger/audit/logger_test.go | 2 +- service/logger/audit/schema.go | 19 +++++++++++-------- service/logger/audit/schema_test.go | 12 ++++++------ service/pkg/config/config_test.go | 8 ++++++-- 4 files changed, 24 insertions(+), 17 deletions(-) diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index 6031dcde58..e8ffdf91e2 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -520,8 +520,8 @@ func TestAuditApplyConfigRejectsReservedPaths(t *testing.T) { }) require.Error(t, err) + require.ErrorIs(t, err, ErrReservedAuditPath) require.ErrorContains(t, err, "jwt_claim_mappings[0].path") - require.ErrorContains(t, err, "reserved audit path") } func TestDeferredRewrapSuccess(t *testing.T) { diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go index ade402d32c..441e3fab08 100644 --- a/service/logger/audit/schema.go +++ b/service/logger/audit/schema.go @@ -8,9 +8,12 @@ import ( ) var ( - errReservedAuditPath = errors.New("reserved audit path") - errUnknownAuditPath = errors.New("unknown audit path") - errAuditContainerPath = errors.New("audit path resolves to a container") + // ErrReservedAuditPath indicates a claim destination targets a protected audit field. + ErrReservedAuditPath = errors.New("reserved audit path") + // ErrUnknownAuditPath indicates a claim destination traverses an unknown closed-schema path. + ErrUnknownAuditPath = errors.New("unknown audit path") + // ErrAuditContainerPath indicates a claim destination resolves to a container instead of a writable leaf. + ErrAuditContainerPath = errors.New("audit path resolves to a container") auditClaimDestinationSchema = mustBuildAuditPathSchema(reflect.TypeOf(EventObject{})) ) @@ -155,14 +158,14 @@ func indirectType(t reflect.Type) reflect.Type { func validateClaimDestinationPath(path string) error { if path == "" { - return fmt.Errorf("%w: empty path", errUnknownAuditPath) + return fmt.Errorf("%w: empty path", ErrUnknownAuditPath) } segments := strings.Split(path, ".") current := auditClaimDestinationSchema for idx, segment := range segments { if segment == "" { - return fmt.Errorf("%w: %s", errUnknownAuditPath, path) + return fmt.Errorf("%w: %s", ErrUnknownAuditPath, path) } child, ok := current.children[segment] @@ -170,16 +173,16 @@ func validateClaimDestinationPath(path string) error { if current.extensible { return nil } - return fmt.Errorf("%w: %s", errUnknownAuditPath, path) + return fmt.Errorf("%w: %s", ErrUnknownAuditPath, path) } isLast := idx == len(segments)-1 if isLast { switch { case child.reserved: - return fmt.Errorf("%w: %s", errReservedAuditPath, path) + return fmt.Errorf("%w: %s", ErrReservedAuditPath, path) case child.extensible || len(child.children) > 0: - return fmt.Errorf("%w: %s", errAuditContainerPath, path) + return fmt.Errorf("%w: %s", ErrAuditContainerPath, path) default: return nil } diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go index 9e58bbbc16..8839d786e0 100644 --- a/service/logger/audit/schema_test.go +++ b/service/logger/audit/schema_test.go @@ -24,27 +24,27 @@ func TestValidateClaimDestinationPath(t *testing.T) { t.Run("rejects reserved paths", func(t *testing.T) { err := validateClaimDestinationPath("requestID") - require.ErrorIs(t, err, errReservedAuditPath) + require.ErrorIs(t, err, ErrReservedAuditPath) err = validateClaimDestinationPath("action.result") - require.ErrorIs(t, err, errReservedAuditPath) + require.ErrorIs(t, err, ErrReservedAuditPath) }) t.Run("rejects container paths", func(t *testing.T) { err := validateClaimDestinationPath("eventMetaData") - require.ErrorIs(t, err, errAuditContainerPath) + require.ErrorIs(t, err, ErrAuditContainerPath) err = validateClaimDestinationPath("object") - require.ErrorIs(t, err, errAuditContainerPath) + require.ErrorIs(t, err, ErrAuditContainerPath) }) t.Run("rejects unknown nested paths below closed containers", func(t *testing.T) { err := validateClaimDestinationPath("object.extra.foo") - require.ErrorIs(t, err, errUnknownAuditPath) + require.ErrorIs(t, err, ErrUnknownAuditPath) }) t.Run("rejects leading dot paths", func(t *testing.T) { err := validateClaimDestinationPath(".banana") - require.ErrorIs(t, err, errUnknownAuditPath) + require.ErrorIs(t, err, ErrUnknownAuditPath) }) } diff --git a/service/pkg/config/config_test.go b/service/pkg/config/config_test.go index 129918217d..8b1a91917f 100644 --- a/service/pkg/config/config_test.go +++ b/service/pkg/config/config_test.go @@ -357,6 +357,7 @@ func TestLoad_Precedence(t *testing.T) { setupLoaders func(t *testing.T, configFile string) []Loader envVars map[string]string err error + causeErr error errContains string fileContent string asserts func(t *testing.T, cfg *Config) @@ -465,8 +466,8 @@ special_key: require.NoError(t, err) return []Loader{file, defaults} }, - err: ErrUnmarshallingConfig, - errContains: "reserved audit path", + err: ErrUnmarshallingConfig, + causeErr: auditcfg.ErrReservedAuditPath, fileContent: ` audit: jwt_claim_mappings: @@ -711,6 +712,9 @@ logger: if tc.err != nil { require.Error(t, err) require.ErrorIs(t, err, tc.err) + if tc.causeErr != nil { + require.ErrorIs(t, err, tc.causeErr) + } if tc.errContains != "" { require.ErrorContains(t, err, tc.errContains) } From 138f9f485ee9a7521e729821a24b46c76cc3b11e Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 15:11:13 -0700 Subject: [PATCH 11/20] chore(config): remove empty kas stanzas Signed-off-by: jakedoublev --- opentdf-core-mode.yaml | 2 -- opentdf-ers-mode.yaml | 1 - 2 files changed, 3 deletions(-) diff --git a/opentdf-core-mode.yaml b/opentdf-core-mode.yaml index a3177d4a63..4350cbb70a 100644 --- a/opentdf-core-mode.yaml +++ b/opentdf-core-mode.yaml @@ -19,8 +19,6 @@ audit: # path: eventMetaData.requester.sub # - claim: realm_access.roles # path: eventMetaData.requester.roles -services: - kas: # DB and Server configurations are defaulted for local development # db: # host: localhost diff --git a/opentdf-ers-mode.yaml b/opentdf-ers-mode.yaml index aa38c8d23e..95186c759b 100644 --- a/opentdf-ers-mode.yaml +++ b/opentdf-ers-mode.yaml @@ -14,7 +14,6 @@ audit: # - claim: realm_access.roles # path: eventMetaData.requester.roles services: - kas: entityresolution: log_level: info url: http://localhost:8888/auth From 560ce14603af045a3d228586f82e008f73047658 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 15:29:49 -0700 Subject: [PATCH 12/20] fix(audit): tighten JWT claim transport Signed-off-by: jakedoublev --- service/internal/auth/authn.go | 49 ++++++++++++++++++- service/logger/audit/logger_test.go | 46 +++++++----------- service/logger/audit/schema_test.go | 7 ++- service/logger/audit/utils.go | 4 +- service/pkg/auth/context_auth.go | 69 ++------------------------- service/pkg/auth/context_auth_test.go | 59 ----------------------- 6 files changed, 75 insertions(+), 159 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index 2042ef1602..c1266288ee 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -438,10 +438,55 @@ func IPCMetadataClientInterceptor(log *logger.Logger) connect.UnaryInterceptorFu }) } +// rehydrateIPCAuthContext reconstructs pkg/auth context from propagated IPC metadata. +// It only transports an already-authenticated token across an internal hop; it does not +// validate the token again. Network-facing requests must continue to rely on the +// ConnectUnaryServerInterceptor auth middleware for validation. +func rehydrateIPCAuthContext(ctx context.Context, l *logger.Logger) (context.Context, error) { + if ctxAuth.GetAccessTokenFromContext(ctx, l) != nil && ctxAuth.GetRawAccessTokenFromContext(ctx, l) != "" { + return ctx, nil + } + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx, nil + } + + rawToken := rawAccessTokenFromIncomingMetadata(md) + if rawToken == "" { + return ctx, nil + } + + parsed, err := jwt.Parse([]byte(rawToken), jwt.WithVerify(false), jwt.WithValidate(false)) + if err != nil { + if l != nil { + l.ErrorContext(ctx, "failed to rehydrate IPC access token from metadata", slog.Any("error", err)) + } + return ctx, fmt.Errorf("rehydrate IPC access token from metadata: %w", err) + } + + return ctxAuth.ContextWithAuthNInfo(ctx, nil, parsed, rawToken), nil +} + +func rawAccessTokenFromIncomingMetadata(md metadata.MD) string { + if accessTokens := md.Get(ctxAuth.AccessTokenKey); len(accessTokens) > 0 && accessTokens[0] != "" { + return accessTokens[0] + } + + authHeaders := md.Get("authorization") + if len(authHeaders) == 0 { + authHeaders = md.Get("Authorization") + } + if len(authHeaders) == 0 || authHeaders[0] == "" { + return "" + } + return strings.TrimPrefix(strings.TrimPrefix(authHeaders[0], "Bearer "), "DPoP ") +} + // IPCUnaryServerInterceptor is a grpc interceptor that: // 1. translates known IPC Connect request headers back to incoming metadata // 2. reauthorizes routes that are configured for IPC reauth -// 3. rehydrates auth context from propagated incoming metadata +// 3. rehydrates auth context from propagated incoming metadata without revalidating it func (a Authentication) IPCUnaryServerInterceptor() connect.UnaryInterceptorFunc { interceptor := func(next connect.UnaryFunc) connect.UnaryFunc { return connect.UnaryFunc(func( @@ -467,7 +512,7 @@ func (a Authentication) IPCUnaryServerInterceptor() connect.UnaryInterceptorFunc if err != nil { return nil, err } - nextCtx, err = ctxAuth.RehydrateAccessTokenFromIncomingMetadata(nextCtx, a.logger) + nextCtx, err = rehydrateIPCAuthContext(nextCtx, a.logger) if err != nil { return nil, connect.NewError(connect.CodeInternal, errors.New("failed to rehydrate IPC authentication context")) } diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index e8ffdf91e2..73dd498ab4 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -16,7 +16,6 @@ import ( ctxAuth "github.com/opentdf/platform/service/pkg/auth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc/metadata" ) // Params @@ -389,32 +388,6 @@ func TestAuditJWTClaimMappingsApplyToPolicyAudit(t *testing.T) { assert.Equal(t, true, requester["emailVerified"]) } -func TestAuditJWTClaimMappingsUseRehydratedAuthContext(t *testing.T) { - _, rawToken := createTestJWTForAudit(t) - - logEntry, _ := doWithLogger(t, func(ctx context.Context) context.Context { - ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(ctxAuth.AccessTokenKey, rawToken)) - var err error - ctx, err = ctxAuth.RehydrateAccessTokenFromIncomingMetadata(ctx, nil) - require.NoError(t, err) - return ctx - }, func(ctx context.Context, l *Logger) { - require.NoError(t, l.ApplyConfig(Config{ - JWTClaimMappings: []JWTClaimMapping{ - {Claim: "sub", Path: "eventMetaData.requester.sub"}, - }, - })) - l.PolicyCRUDSuccess(ctx, policyCRUDParams) - }) - - payload := decodeAuditPayload(t, logEntry.Audit) - eventMetaData, ok := payload["eventMetaData"].(map[string]any) - require.True(t, ok) - requester, ok := eventMetaData["requester"].(map[string]any) - require.True(t, ok) - assert.Equal(t, "jwt-user", requester["sub"]) -} - func TestAuditJWTClaimMappingsCanWriteToEntityMetadata(t *testing.T) { token, rawToken := createTestJWTForAudit(t) @@ -511,11 +484,26 @@ func TestAuditJWTClaimMappingsLeaveReservedFieldsUntouched(t *testing.T) { } func TestAuditApplyConfigRejectsReservedPaths(t *testing.T) { - l, _ := createTestLogger() + t.Run("requestID", func(t *testing.T) { + assertReservedAuditPathRejected(t, "requestID") + }) + + t.Run("clientInfo.userAgent", func(t *testing.T) { + assertReservedAuditPathRejected(t, "clientInfo.userAgent") + }) + + t.Run("clientInfo.requestIP", func(t *testing.T) { + assertReservedAuditPathRejected(t, "clientInfo.requestIP") + }) +} +func assertReservedAuditPathRejected(t *testing.T, path string) { + t.Helper() + + l, _ := createTestLogger() err := l.ApplyConfig(Config{ JWTClaimMappings: []JWTClaimMapping{ - {Claim: "sub", Path: "requestID"}, + {Claim: "sub", Path: path}, }, }) diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go index 8839d786e0..878b2fdad3 100644 --- a/service/logger/audit/schema_test.go +++ b/service/logger/audit/schema_test.go @@ -9,7 +9,6 @@ import ( func TestValidateClaimDestinationPath(t *testing.T) { t.Run("allows writable leaf paths", func(t *testing.T) { require.NoError(t, validateClaimDestinationPath("object.id")) - require.NoError(t, validateClaimDestinationPath("clientInfo.requestIP")) }) t.Run("allows nested paths below extensible maps", func(t *testing.T) { @@ -28,6 +27,12 @@ func TestValidateClaimDestinationPath(t *testing.T) { err = validateClaimDestinationPath("action.result") require.ErrorIs(t, err, ErrReservedAuditPath) + + err = validateClaimDestinationPath("clientInfo.userAgent") + require.ErrorIs(t, err, ErrReservedAuditPath) + + err = validateClaimDestinationPath("clientInfo.requestIP") + require.ErrorIs(t, err, ErrReservedAuditPath) }) t.Run("rejects container paths", func(t *testing.T) { diff --git a/service/logger/audit/utils.go b/service/logger/audit/utils.go index a65a7adac6..79d06eb433 100644 --- a/service/logger/audit/utils.go +++ b/service/logger/audit/utils.go @@ -130,9 +130,9 @@ func (e auditEventActor) LogValue() slog.Value { // event.clientInfo type eventClientInfo struct { - UserAgent string `json:"userAgent" audit:"userAgent"` + UserAgent string `json:"userAgent" audit:"userAgent,reserved"` Platform string `json:"platform" audit:"platform,reserved"` - RequestIP string `json:"requestIp" audit:"requestIP"` + RequestIP string `json:"requestIp" audit:"requestIP,reserved"` } func (e eventClientInfo) LogValue() slog.Value { diff --git a/service/pkg/auth/context_auth.go b/service/pkg/auth/context_auth.go index d1ac292bc6..e6a55279f6 100644 --- a/service/pkg/auth/context_auth.go +++ b/service/pkg/auth/context_auth.go @@ -3,8 +3,6 @@ package auth import ( "context" "errors" - "fmt" - "strings" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" @@ -85,33 +83,11 @@ func GetRawAccessTokenFromContext(ctx context.Context, l optionalErrorLogger) st return "" } -// RehydrateAccessTokenFromIncomingMetadata reconstructs auth context from incoming metadata -// so downstream code can use the normal accessors without transport-specific fallbacks. -func RehydrateAccessTokenFromIncomingMetadata(ctx context.Context, l optionalErrorLogger) (context.Context, error) { - if c := getContextDetails(ctx, l); c != nil && c.accessToken != nil && c.rawToken != "" { - return ctx, nil - } - - rawToken := getRawAccessTokenFromMetadata(ctx, true) - if rawToken == "" { - return ctx, nil - } - - parsed, err := jwt.Parse([]byte(rawToken), jwt.WithVerify(false), jwt.WithValidate(false)) - if err != nil { - if l != nil { - l.ErrorContext(ctx, "failed to rehydrate access token from incoming metadata", "error", err) - } - return ctx, fmt.Errorf("rehydrate access token from incoming metadata: %w", err) - } - - return ContextWithAuthNInfo(ctx, nil, parsed, rawToken), nil -} - // EnrichIncomingContextMetadataWithAuthn adds the access token and client ID to incoming context metadata // -// Adding the authn info to gRPC metadata propagates it across services rather than strictly -// in-process within Go alone +// This transports an already-authenticated token across internal hops. It does not +// validate the token again; the ConnectUnaryServerInterceptor auth middleware does that +// before ContextWithAuthNInfo is created. func EnrichIncomingContextMetadataWithAuthn(ctx context.Context, l optionalErrorLogger, clientID string) context.Context { rawToken := GetRawAccessTokenFromContext(ctx, l) @@ -132,45 +108,6 @@ func EnrichIncomingContextMetadataWithAuthn(ctx context.Context, l optionalError return metadata.NewIncomingContext(ctx, md) } -func getRawAccessTokenFromMetadata(ctx context.Context, incoming bool) string { - var ( - md metadata.MD - ok bool - ) - if incoming { - md, ok = metadata.FromIncomingContext(ctx) - } else { - md, ok = metadata.FromOutgoingContext(ctx) - } - if !ok { - return "" - } - - if accessTokens := md.Get(AccessTokenKey); len(accessTokens) > 0 && accessTokens[0] != "" { - return accessTokens[0] - } - - authHeaders := md.Get("authorization") - if len(authHeaders) == 0 { - authHeaders = md.Get("Authorization") - } - if len(authHeaders) > 0 { - return trimAuthorizationHeader(authHeaders[0]) - } - return "" -} - -func trimAuthorizationHeader(header string) string { - switch { - case strings.HasPrefix(header, "Bearer "): - return strings.TrimPrefix(header, "Bearer ") - case strings.HasPrefix(header, "DPoP "): - return strings.TrimPrefix(header, "DPoP ") - default: - return header - } -} - // GetClientIDFromContext retrieves the client ID from the metadata in the context func GetClientIDFromContext(ctx context.Context, incoming bool) (string, error) { var ( diff --git a/service/pkg/auth/context_auth_test.go b/service/pkg/auth/context_auth_test.go index f88cc4b2fd..e427c5eb72 100644 --- a/service/pkg/auth/context_auth_test.go +++ b/service/pkg/auth/context_auth_test.go @@ -92,65 +92,6 @@ func TestGetAccessTokenFromContextDoesNotFallbackToMetadata(t *testing.T) { assert.Nil(t, retrievedJWT) } -func TestRehydrateAccessTokenFromIncomingMetadata(t *testing.T) { - t.Run("rehydrates from access token metadata", func(t *testing.T) { - mockJWT, err := jwt.NewBuilder(). - Subject("metadata-user"). - Claim("roles", []string{"admin"}). - Build() - require.NoError(t, err) - - rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) - require.NoError(t, err) - - ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, string(rawToken))) - rehydratedCtx, err := RehydrateAccessTokenFromIncomingMetadata(ctx, nil) - require.NoError(t, err) - - retrievedJWT := GetAccessTokenFromContext(rehydratedCtx, nil) - require.NotNil(t, retrievedJWT) - assert.Equal(t, "metadata-user", retrievedJWT.Subject()) - assert.Equal(t, string(rawToken), GetRawAccessTokenFromContext(rehydratedCtx, nil)) - }) - - t.Run("uses authorization metadata", func(t *testing.T) { - mockJWT, err := jwt.NewBuilder().Subject("authorization-user").Build() - require.NoError(t, err) - - rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) - require.NoError(t, err) - - ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs("Authorization", "Bearer "+string(rawToken))) - rehydratedCtx, err := RehydrateAccessTokenFromIncomingMetadata(ctx, nil) - require.NoError(t, err) - - retrievedJWT := GetAccessTokenFromContext(rehydratedCtx, nil) - require.NotNil(t, retrievedJWT) - assert.Equal(t, "authorization-user", retrievedJWT.Subject()) - }) - - t.Run("returns unchanged context when auth context already exists", func(t *testing.T) { - existingJWT, err := jwt.NewBuilder().Subject("existing-user").Build() - require.NoError(t, err) - - ctx := ContextWithAuthNInfo(t.Context(), nil, existingJWT, "existing-raw-token") - ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(AccessTokenKey, "different-token")) - rehydratedCtx, err := RehydrateAccessTokenFromIncomingMetadata(ctx, nil) - require.NoError(t, err) - - retrievedJWT := GetAccessTokenFromContext(rehydratedCtx, nil) - require.NotNil(t, retrievedJWT) - assert.Equal(t, "existing-user", retrievedJWT.Subject()) - assert.Equal(t, "existing-raw-token", GetRawAccessTokenFromContext(rehydratedCtx, nil)) - }) - - t.Run("returns error on invalid token metadata", func(t *testing.T) { - ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(AccessTokenKey, "not-a-jwt")) - _, err := RehydrateAccessTokenFromIncomingMetadata(ctx, nil) - require.Error(t, err) - }) -} - func TestGetContextDetailsInvalidType(t *testing.T) { // Create a context with an invalid type ctx := context.WithValue(t.Context(), authnContextKey, "invalidType") From ca170c2f3fd702ca1faa53dbd5a583e82a65086a Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 15:40:16 -0700 Subject: [PATCH 13/20] fix(audit): address review feedback Signed-off-by: jakedoublev --- service/internal/auth/authn.go | 2 +- .../authn_ipc_metadata_interceptor_test.go | 22 +++++++++++ service/internal/dotnotation/dotnotation.go | 8 ++++ .../internal/dotnotation/dotnotation_test.go | 39 +++++++++++++++++++ service/logger/audit/logger.go | 10 ++++- service/logger/audit/logger_test.go | 29 ++++++++++++++ service/logger/audit/schema.go | 9 ++++- service/logger/audit/schema_test.go | 4 ++ service/pkg/server/services.go | 14 +++---- service/pkg/server/services_test.go | 14 +++++++ 10 files changed, 140 insertions(+), 11 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index c1266288ee..096ee2e2aa 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -465,7 +465,7 @@ func rehydrateIPCAuthContext(ctx context.Context, l *logger.Logger) (context.Con return ctx, fmt.Errorf("rehydrate IPC access token from metadata: %w", err) } - return ctxAuth.ContextWithAuthNInfo(ctx, nil, parsed, rawToken), nil + return ctxAuth.ContextWithAuthNInfo(ctx, ctxAuth.GetJWKFromContext(ctx, l), parsed, rawToken), nil } func rawAccessTokenFromIncomingMetadata(md metadata.MD) string { diff --git a/service/internal/auth/authn_ipc_metadata_interceptor_test.go b/service/internal/auth/authn_ipc_metadata_interceptor_test.go index 0ac1603835..49238d67ad 100644 --- a/service/internal/auth/authn_ipc_metadata_interceptor_test.go +++ b/service/internal/auth/authn_ipc_metadata_interceptor_test.go @@ -353,3 +353,25 @@ func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { assert.Equal(t, connect.CodeInternal, connectErr.Code()) }) } + +func TestRehydrateIPCAuthContextPreservesExistingJWK(t *testing.T) { + testLogger := logger.CreateTestLogger() + mockJWT, err := jwt.NewBuilder().Subject("rehydrated-user").Build() + require.NoError(t, err) + rawToken, err := jwt.Sign(mockJWT, jwt.WithInsecureNoSignature()) + require.NoError(t, err) + mockJWK, err := jwk.FromRaw([]byte("existing-jwk")) + require.NoError(t, err) + + ctx := metadata.NewIncomingContext(t.Context(), metadata.Pairs(ctxAuth.AccessTokenKey, string(rawToken))) + ctx = ctxAuth.ContextWithAuthNInfo(ctx, mockJWK, nil, "") + + rehydratedCtx, err := rehydrateIPCAuthContext(ctx, testLogger) + require.NoError(t, err) + require.Same(t, mockJWK, ctxAuth.GetJWKFromContext(rehydratedCtx, testLogger)) + + retrievedJWT := ctxAuth.GetAccessTokenFromContext(rehydratedCtx, testLogger) + require.NotNil(t, retrievedJWT) + assert.Equal(t, "rehydrated-user", retrievedJWT.Subject()) + assert.Equal(t, string(rawToken), ctxAuth.GetRawAccessTokenFromContext(rehydratedCtx, testLogger)) +} diff --git a/service/internal/dotnotation/dotnotation.go b/service/internal/dotnotation/dotnotation.go index 7a021fd834..0992f60dd6 100644 --- a/service/internal/dotnotation/dotnotation.go +++ b/service/internal/dotnotation/dotnotation.go @@ -29,11 +29,19 @@ func Get(m map[string]any, key string) any { // Set stores a value in a nested map using dot notation keys, creating // intermediate maps as needed. func Set(m map[string]any, key string, value any) error { + if m == nil { + return errors.New("nil root map") + } if key == "" { return errors.New("empty path") } keys := strings.Split(key, ".") + for _, segment := range keys { + if segment == "" { + return fmt.Errorf("invalid path %q: empty segment", key) + } + } current := m for i, k := range keys[:len(keys)-1] { next, exists := current[k] diff --git a/service/internal/dotnotation/dotnotation_test.go b/service/internal/dotnotation/dotnotation_test.go index 824abf4320..9a9ba5cd59 100644 --- a/service/internal/dotnotation/dotnotation_test.go +++ b/service/internal/dotnotation/dotnotation_test.go @@ -70,4 +70,43 @@ func TestSet(t *testing.T) { t.Fatal("expected collision error") } }) + + t.Run("fails on nil root map", func(t *testing.T) { + defer func() { + if recovered := recover(); recovered != nil { + t.Fatalf("Set should not panic, got %v", recovered) + } + }() + + if err := Set(nil, "a.b", "value"); err == nil { + t.Fatal("expected error for nil root map") + } + }) + + t.Run("fails on malformed paths", func(t *testing.T) { + for _, path := range []string{"a..b", ".a", "a."} { + t.Run(path, func(t *testing.T) { + defer func() { + if recovered := recover(); recovered != nil { + t.Fatalf("Set should not panic, got %v", recovered) + } + }() + + input := make(map[string]any) + if err := Set(input, path, "value"); err == nil { + t.Fatal("expected malformed path error") + } + + if got := Get(input, "a"); got != nil { + t.Fatalf("expected no root value for a, got %v", got) + } + if got := Get(input, "a."); got != nil { + t.Fatalf("expected no nested invalid value, got %v", got) + } + if _, exists := input[""]; exists { + t.Fatal("expected no empty-string root key") + } + }) + } + }) } diff --git a/service/logger/audit/logger.go b/service/logger/audit/logger.go index 538882923e..5ca12509b9 100644 --- a/service/logger/audit/logger.go +++ b/service/logger/audit/logger.go @@ -62,12 +62,18 @@ func CreateAuditLogger(logger slog.Logger) *Logger { } } +func cloneConfig(cfg Config) Config { + cloned := cfg + cloned.JWTClaimMappings = append([]JWTClaimMapping(nil), cfg.JWTClaimMappings...) + return cloned +} + // ApplyConfig validates and stores the latest audit enrichment configuration. func (a *Logger) ApplyConfig(cfg Config) error { if err := cfg.Validate(); err != nil { return err } - a.config = cfg + a.config = cloneConfig(cfg) return nil } @@ -75,7 +81,7 @@ func (a *Logger) With(key string, value string) *Logger { return &Logger{ //nolint:sloglint // custom logger should support key/value pairs in With attributes logger: a.logger.With(key, value), - config: a.config, + config: cloneConfig(a.config), } } diff --git a/service/logger/audit/logger_test.go b/service/logger/audit/logger_test.go index 73dd498ab4..c2000928c2 100644 --- a/service/logger/audit/logger_test.go +++ b/service/logger/audit/logger_test.go @@ -497,6 +497,35 @@ func TestAuditApplyConfigRejectsReservedPaths(t *testing.T) { }) } +func TestAuditApplyConfigClonesMappings(t *testing.T) { + l, _ := createTestLogger() + cfg := Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + }, + } + + require.NoError(t, l.ApplyConfig(cfg)) + + cfg.JWTClaimMappings[0].Path = "eventMetaData.requester.changed" + + require.Equal(t, "eventMetaData.requester.sub", l.config.JWTClaimMappings[0].Path) +} + +func TestAuditLoggerWithClonesMappings(t *testing.T) { + l, _ := createTestLogger() + require.NoError(t, l.ApplyConfig(Config{ + JWTClaimMappings: []JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + }, + })) + + child := l.With("namespace", "policy") + l.config.JWTClaimMappings[0].Path = "eventMetaData.requester.changed" + + require.Equal(t, "eventMetaData.requester.sub", child.config.JWTClaimMappings[0].Path) +} + func assertReservedAuditPathRejected(t *testing.T, path string) { t.Helper() diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go index 441e3fab08..e68b743322 100644 --- a/service/logger/audit/schema.go +++ b/service/logger/audit/schema.go @@ -26,6 +26,7 @@ type auditFieldOptions struct { type auditPathSchema struct { children map[string]*auditPathSchema + isLeaf bool reserved bool extensible bool } @@ -71,6 +72,7 @@ func addAuditSchemaFields(parent *auditPathSchema, structType reflect.Type) erro child := &auditPathSchema{ children: make(map[string]*auditPathSchema), + isLeaf: isWritableAuditLeaf(indirectType(field.Type)), reserved: opts.reserved, extensible: opts.extensible, } @@ -156,6 +158,11 @@ func indirectType(t reflect.Type) reflect.Type { return t } +func isWritableAuditLeaf(t reflect.Type) bool { + kind := indirectType(t).Kind() + return kind != reflect.Struct && kind != reflect.Map +} + func validateClaimDestinationPath(path string) error { if path == "" { return fmt.Errorf("%w: empty path", ErrUnknownAuditPath) @@ -181,7 +188,7 @@ func validateClaimDestinationPath(path string) error { switch { case child.reserved: return fmt.Errorf("%w: %s", ErrReservedAuditPath, path) - case child.extensible || len(child.children) > 0: + case child.extensible || !child.isLeaf: return fmt.Errorf("%w: %s", ErrAuditContainerPath, path) default: return nil diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go index 878b2fdad3..e326af8475 100644 --- a/service/logger/audit/schema_test.go +++ b/service/logger/audit/schema_test.go @@ -9,6 +9,7 @@ import ( func TestValidateClaimDestinationPath(t *testing.T) { t.Run("allows writable leaf paths", func(t *testing.T) { require.NoError(t, validateClaimDestinationPath("object.id")) + require.NoError(t, validateClaimDestinationPath("actor.attributes")) }) t.Run("allows nested paths below extensible maps", func(t *testing.T) { @@ -41,6 +42,9 @@ func TestValidateClaimDestinationPath(t *testing.T) { err = validateClaimDestinationPath("object") require.ErrorIs(t, err, ErrAuditContainerPath) + + err = validateClaimDestinationPath("object.attributes") + require.ErrorIs(t, err, ErrAuditContainerPath) }) t.Run("rejects unknown nested paths below closed containers", func(t *testing.T) { diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index 2ae39f1665..20e022ba5f 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -290,15 +290,15 @@ func buildNamespaceLogger(baseLogger *logging.Logger, cfg *config.Config, ns, le newLoggerConfig := cfg.Logger newLoggerConfig.Level = level - if namespaceLogger, loggerErr := logging.NewLogger(newLoggerConfig); loggerErr == nil { - if err := namespaceLogger.Audit.ApplyConfig(cfg.Audit); err != nil { - return nil, fmt.Errorf("could not apply audit config for namespace %s: %w", ns, err) - } - return namespaceLogger.With("namespace", ns), nil + namespaceLogger, loggerErr := logging.NewLogger(newLoggerConfig) + if loggerErr != nil { + return nil, fmt.Errorf("invalid namespace logger config for %s: %w", ns, loggerErr) } - // Keep the existing logger if the override could not be created. - return baseLogger, nil + if err := namespaceLogger.Audit.ApplyConfig(cfg.Audit); err != nil { + return nil, fmt.Errorf("could not apply audit config for namespace %s: %w", ns, err) + } + return namespaceLogger.With("namespace", ns), nil } // newServiceDBClient creates a new database client for the specified namespace. diff --git a/service/pkg/server/services_test.go b/service/pkg/server/services_test.go index 0fc7130161..e08645cb5f 100644 --- a/service/pkg/server/services_test.go +++ b/service/pkg/server/services_test.go @@ -220,6 +220,20 @@ func (suite *ServiceTestSuite) Test_RegisterServices_In_Mode_Core_Plus_Kas_Expec suite.Equal(serviceregistry.ModeERS.String(), ers.Mode) } +func (suite *ServiceTestSuite) TestBuildNamespaceLoggerRejectsInvalidOverrideLevel() { + baseLogger, err := logger.NewLogger(logger.Config{Output: "stdout", Level: "info", Type: "json"}) + suite.Require().NoError(err) + + cfg := &config.Config{ + Logger: logger.Config{Output: "stdout", Level: "info", Type: "json"}, + } + + namespaceLogger, err := buildNamespaceLogger(baseLogger, cfg, "policy", "not-a-level") + suite.Require().Error(err) + suite.Nil(namespaceLogger) + suite.ErrorContains(err, "invalid namespace logger config for policy") +} + func (suite *ServiceTestSuite) TestStartServicesWithVariousCases() { ctx := context.Background() From ba0a3091a09d6a8cb5416e6226a499251cb16d53 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 15:46:07 -0700 Subject: [PATCH 14/20] test(server): use standard test logger helper Signed-off-by: jakedoublev --- service/pkg/server/services_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/service/pkg/server/services_test.go b/service/pkg/server/services_test.go index e08645cb5f..b33e72de20 100644 --- a/service/pkg/server/services_test.go +++ b/service/pkg/server/services_test.go @@ -221,8 +221,7 @@ func (suite *ServiceTestSuite) Test_RegisterServices_In_Mode_Core_Plus_Kas_Expec } func (suite *ServiceTestSuite) TestBuildNamespaceLoggerRejectsInvalidOverrideLevel() { - baseLogger, err := logger.NewLogger(logger.Config{Output: "stdout", Level: "info", Type: "json"}) - suite.Require().NoError(err) + baseLogger := logger.CreateTestLogger() cfg := &config.Config{ Logger: logger.Config{Output: "stdout", Level: "info", Type: "json"}, From 0811a3c876f8e7ed2fefb7caf45b73cbee95b450 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Fri, 1 May 2026 16:19:38 -0700 Subject: [PATCH 15/20] refactor(audit): derive payload shape from tags Signed-off-by: jakedoublev --- service/logger/audit/enrichment.go | 44 ++++++++++++++- service/logger/audit/schema.go | 35 ++++-------- service/logger/audit/utils.go | 90 ++++++++++-------------------- service/logger/logger.go | 2 +- 4 files changed, 84 insertions(+), 87 deletions(-) diff --git a/service/logger/audit/enrichment.go b/service/logger/audit/enrichment.go index 90eabd6b73..3179d7e845 100644 --- a/service/logger/audit/enrichment.go +++ b/service/logger/audit/enrichment.go @@ -2,6 +2,8 @@ package audit import ( "context" + "encoding" + "encoding/json" "log/slog" "reflect" @@ -10,7 +12,7 @@ import ( ) func (a *Logger) buildLogEntry(ctx context.Context, event *EventObject) map[string]any { - entry := event.logMap() + entry := event.emittedPayloadMap() a.applyJWTClaimEnrichment(ctx, entry) return entry } @@ -80,9 +82,49 @@ func normalizeAuditValue(value any) any { return normalized } + if marshaler, ok := value.(json.Marshaler); ok { + encoded, err := marshaler.MarshalJSON() + if err == nil { + var decoded any + if err := json.Unmarshal(encoded, &decoded); err == nil { + return decoded + } + } + } + + if marshaler, ok := value.(encoding.TextMarshaler); ok { + encoded, err := marshaler.MarshalText() + if err == nil { + return string(encoded) + } + } + rv := reflect.ValueOf(value) + if !rv.IsValid() { + return nil + } //nolint:exhaustive // only composite kinds need normalization; scalars can pass through unchanged switch rv.Kind() { + case reflect.Pointer: + if rv.IsNil() { + return nil + } + return normalizeAuditValue(rv.Elem().Interface()) + case reflect.Struct: + structType := rv.Type() + normalized := make(map[string]any, structType.NumField()) + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + if !field.IsExported() { + continue + } + opts, ok := parseAuditFieldOptions(field) + if !ok { + continue + } + normalized[opts.name] = normalizeAuditValue(rv.Field(i).Interface()) + } + return normalized case reflect.Map: if rv.IsNil() { return nil diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go index e68b743322..72c2b581d8 100644 --- a/service/logger/audit/schema.go +++ b/service/logger/audit/schema.go @@ -98,10 +98,8 @@ func parseAuditFieldOptions(field reflect.StructField) (auditFieldOptions, bool) return auditFieldOptions{}, false } - name, reserved, extensible := parseAuditTag(tag) - if name == "" { - name = parseJSONFieldName(field) - } + reserved, extensible := parseAuditTag(tag) + name := parseJSONFieldName(field) if name == "" { return auditFieldOptions{}, false } @@ -113,26 +111,17 @@ func parseAuditFieldOptions(field reflect.StructField) (auditFieldOptions, bool) }, true } -func parseAuditTag(tag string) (string, bool, bool) { - if tag == "" { - return "", false, false - } - - parts := strings.Split(tag, ",") - name := parts[0] - var ( - reserved bool - extensible bool - ) - for _, option := range parts[1:] { - switch option { - case "reserved": - reserved = true - case "extensible": - extensible = true - } +func parseAuditTag(tag string) (bool, bool) { + switch tag { + case "": + return false, false + case "reserved": + return true, false + case "extensible": + return false, true + default: + return false, false } - return name, reserved, extensible } func parseJSONFieldName(field reflect.StructField) string { diff --git a/service/logger/audit/utils.go b/service/logger/audit/utils.go index 79d06eb433..743b598072 100644 --- a/service/logger/audit/utils.go +++ b/service/logger/audit/utils.go @@ -16,70 +16,36 @@ type auditEventMetadata map[string]any // event type EventObject struct { - Object auditEventObject `json:"object" audit:"object"` - Action eventAction `json:"action" audit:"action"` - Actor auditEventActor `json:"actor" audit:"actor"` - EventMetaData auditEventMetadata `json:"eventMetaData" audit:"eventMetaData,extensible"` - ClientInfo eventClientInfo `json:"clientInfo" audit:"clientInfo"` + Object auditEventObject `json:"object"` + Action eventAction `json:"action"` + Actor auditEventActor `json:"actor"` + EventMetaData auditEventMetadata `json:"eventMetaData" audit:"extensible"` + ClientInfo eventClientInfo `json:"clientInfo"` - Original map[string]any `json:"original,omitempty" audit:"original,extensible"` - Updated map[string]any `json:"updated,omitempty" audit:"updated,extensible"` - RequestID uuid.UUID `json:"requestId" audit:"requestID,reserved"` - Timestamp string `json:"timestamp" audit:"timestamp,reserved"` + Original map[string]any `json:"original,omitempty" audit:"extensible"` + Updated map[string]any `json:"updated,omitempty" audit:"extensible"` + RequestID uuid.UUID `json:"requestID" audit:"reserved"` + Timestamp string `json:"timestamp" audit:"reserved"` } func (e EventObject) LogValue() slog.Value { - return slog.GroupValue( - slog.Any("object", e.Object), - slog.Any("action", e.Action), - slog.Any("actor", e.Actor), - slog.Any("eventMetaData", e.EventMetaData), - slog.Any("clientInfo", e.ClientInfo), - slog.Any("original", e.Original), - slog.Any("updated", e.Updated), - slog.String("requestID", e.RequestID.String()), - slog.String("timestamp", e.Timestamp)) -} - -func (e EventObject) logMap() map[string]any { - return map[string]any{ - "object": map[string]any{ - "type": e.Object.Type.String(), - "id": e.Object.ID, - "name": e.Object.Name, - "attributes": map[string]any{ - "assertions": e.Object.Attributes.Assertions, - "attrs": e.Object.Attributes.Attrs, - "permissions": e.Object.Attributes.Permissions, - }, - }, - "action": map[string]any{ - "type": e.Action.Type.String(), - "result": e.Action.Result.String(), - }, - "actor": map[string]any{ - "id": e.Actor.ID, - "attributes": e.Actor.Attributes, - }, - "eventMetaData": normalizeAuditValue(e.EventMetaData), - "clientInfo": map[string]any{ - "userAgent": e.ClientInfo.UserAgent, - "platform": e.ClientInfo.Platform, - "requestIP": e.ClientInfo.RequestIP, - }, - "original": normalizeAuditValue(e.Original), - "updated": normalizeAuditValue(e.Updated), - "requestID": e.RequestID.String(), - "timestamp": e.Timestamp, + return slog.AnyValue(e.emittedPayloadMap()) +} + +func (e EventObject) emittedPayloadMap() map[string]any { + entry, ok := normalizeAuditValue(e).(map[string]any) + if !ok { + panic("normalized audit payload must be a map") } + return entry } // event.object type auditEventObject struct { - Type ObjectType `json:"type" audit:"type,reserved"` - ID string `json:"id" audit:"id"` - Name string `json:"name,omitempty" audit:"name"` - Attributes eventObjectAttributes `json:"attributes,omitempty" audit:"attributes"` + Type ObjectType `json:"type" audit:"reserved"` + ID string `json:"id"` + Name string `json:"name,omitempty"` + Attributes eventObjectAttributes `json:"attributes,omitempty"` } func (e auditEventObject) LogValue() slog.Value { @@ -106,8 +72,8 @@ func (e eventObjectAttributes) LogValue() slog.Value { // event.action type eventAction struct { - Type ActionType `json:"type" audit:"type,reserved"` - Result ActionResult `json:"result" audit:"result,reserved"` + Type ActionType `json:"type" audit:"reserved"` + Result ActionResult `json:"result" audit:"reserved"` } func (e eventAction) LogValue() slog.Value { @@ -118,8 +84,8 @@ func (e eventAction) LogValue() slog.Value { // event.actor type auditEventActor struct { - ID string `json:"id" audit:"id,reserved"` - Attributes []any `json:"attributes" audit:"attributes"` + ID string `json:"id" audit:"reserved"` + Attributes []any `json:"attributes"` } func (e auditEventActor) LogValue() slog.Value { @@ -130,9 +96,9 @@ func (e auditEventActor) LogValue() slog.Value { // event.clientInfo type eventClientInfo struct { - UserAgent string `json:"userAgent" audit:"userAgent,reserved"` - Platform string `json:"platform" audit:"platform,reserved"` - RequestIP string `json:"requestIp" audit:"requestIP,reserved"` + UserAgent string `json:"userAgent" audit:"reserved"` + Platform string `json:"platform" audit:"reserved"` + RequestIP string `json:"requestIP" audit:"reserved"` } func (e eventClientInfo) LogValue() slog.Value { diff --git a/service/logger/logger.go b/service/logger/logger.go index eec308ac5b..a96ec7a7a9 100644 --- a/service/logger/logger.go +++ b/service/logger/logger.go @@ -11,7 +11,7 @@ // Original map[string]interface{} `json:"original,omitempty"` // Updated map[string]interface{} `json:"updated,omitempty"` -// RequestID uuid.UUID `json:"requestId"` +// RequestID uuid.UUID `json:"requestID"` // Timestamp string `json:"timestamp"` // } From 6e2e19101d5741130837a110ab90a98a8e645173 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Mon, 4 May 2026 08:20:00 -0700 Subject: [PATCH 16/20] fix(audit): return CodeUnauthenticated for IPC auth failures and reject unknown audit tags Return connect.CodeUnauthenticated instead of CodeInternal when IPC token rehydration fails, since a malformed token is an auth failure not a server fault. Fail fast on unrecognized audit struct tag values (e.g. a typo like audit:"resreved") by returning an error from parseAuditTag, propagated through schema construction to panic at startup. Signed-off-by: jakedoublev --- service/internal/auth/authn.go | 2 +- .../authn_ipc_metadata_interceptor_test.go | 2 +- service/logger/audit/enrichment.go | 2 +- service/logger/audit/schema.go | 28 +++++++++++-------- service/logger/audit/schema_test.go | 10 +++++++ 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index 096ee2e2aa..a81ac4be87 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -514,7 +514,7 @@ func (a Authentication) IPCUnaryServerInterceptor() connect.UnaryInterceptorFunc } nextCtx, err = rehydrateIPCAuthContext(nextCtx, a.logger) if err != nil { - return nil, connect.NewError(connect.CodeInternal, errors.New("failed to rehydrate IPC authentication context")) + return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("invalid IPC authentication context")) } return next(nextCtx, req) }) diff --git a/service/internal/auth/authn_ipc_metadata_interceptor_test.go b/service/internal/auth/authn_ipc_metadata_interceptor_test.go index 49238d67ad..2bd5d8b486 100644 --- a/service/internal/auth/authn_ipc_metadata_interceptor_test.go +++ b/service/internal/auth/authn_ipc_metadata_interceptor_test.go @@ -350,7 +350,7 @@ func TestIPCUnaryServerInterceptor_Integration(t *testing.T) { var connectErr *connect.Error require.ErrorAs(t, err, &connectErr) - assert.Equal(t, connect.CodeInternal, connectErr.Code()) + assert.Equal(t, connect.CodeUnauthenticated, connectErr.Code()) }) } diff --git a/service/logger/audit/enrichment.go b/service/logger/audit/enrichment.go index 3179d7e845..7502e0349f 100644 --- a/service/logger/audit/enrichment.go +++ b/service/logger/audit/enrichment.go @@ -118,7 +118,7 @@ func normalizeAuditValue(value any) any { if !field.IsExported() { continue } - opts, ok := parseAuditFieldOptions(field) + opts, ok, _ := parseAuditFieldOptions(field) if !ok { continue } diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go index 72c2b581d8..91f417d233 100644 --- a/service/logger/audit/schema.go +++ b/service/logger/audit/schema.go @@ -62,7 +62,10 @@ func addAuditSchemaFields(parent *auditPathSchema, structType reflect.Type) erro continue } - opts, ok := parseAuditFieldOptions(field) + opts, ok, err := parseAuditFieldOptions(field) + if err != nil { + return err + } if !ok { continue } @@ -92,35 +95,38 @@ func addAuditSchemaFields(parent *auditPathSchema, structType reflect.Type) erro return nil } -func parseAuditFieldOptions(field reflect.StructField) (auditFieldOptions, bool) { +func parseAuditFieldOptions(field reflect.StructField) (auditFieldOptions, bool, error) { tag := field.Tag.Get("audit") if tag == "-" { - return auditFieldOptions{}, false + return auditFieldOptions{}, false, nil } - reserved, extensible := parseAuditTag(tag) + reserved, extensible, err := parseAuditTag(tag) + if err != nil { + return auditFieldOptions{}, false, fmt.Errorf("field %s: %w", field.Name, err) + } name := parseJSONFieldName(field) if name == "" { - return auditFieldOptions{}, false + return auditFieldOptions{}, false, nil } return auditFieldOptions{ name: name, reserved: reserved, extensible: extensible, - }, true + }, true, nil } -func parseAuditTag(tag string) (bool, bool) { +func parseAuditTag(tag string) (bool, bool, error) { switch tag { case "": - return false, false + return false, false, nil case "reserved": - return true, false + return true, false, nil case "extensible": - return false, true + return false, true, nil default: - return false, false + return false, false, fmt.Errorf("unknown audit tag %q", tag) } } diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go index e326af8475..10229e948d 100644 --- a/service/logger/audit/schema_test.go +++ b/service/logger/audit/schema_test.go @@ -1,6 +1,7 @@ package audit import ( + "reflect" "testing" "github.com/stretchr/testify/require" @@ -57,3 +58,12 @@ func TestValidateClaimDestinationPath(t *testing.T) { require.ErrorIs(t, err, ErrUnknownAuditPath) }) } + +func TestBuildAuditPathSchemaRejectsUnknownTags(t *testing.T) { + type badStruct struct { + Field string `json:"field" audit:"resreved"` + } + _, err := buildAuditPathSchema(reflect.TypeOf(badStruct{})) + require.Error(t, err) + require.ErrorContains(t, err, "unknown audit tag") +} From 21dff6b1d7d7c7af5ad4305a81587b5595c9df3a Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Mon, 4 May 2026 08:26:12 -0700 Subject: [PATCH 17/20] fix(audit): reject overlapping JWT claim mapping paths at config time Two mappings where one path is a prefix of another (e.g. "banana" and "banana.kiwi.mango") would pass individual validation but cause dotnotation.Set collisions at runtime. Detect this during Config.Validate() so the service fails fast at startup instead of emitting partially enriched audit logs per-request. Signed-off-by: jakedoublev --- service/logger/audit/config.go | 36 ++++++++++++++++++++++++++++- service/logger/audit/schema.go | 2 ++ service/logger/audit/schema_test.go | 34 +++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/service/logger/audit/config.go b/service/logger/audit/config.go index 2c547620b7..04abd6562c 100644 --- a/service/logger/audit/config.go +++ b/service/logger/audit/config.go @@ -1,6 +1,9 @@ package audit -import "fmt" +import ( + "fmt" + "strings" +) // JWTClaimMapping maps a JWT claim to a destination in the emitted audit log. // The destination path uses dot notation over the audit log output shape, e.g. @@ -29,5 +32,36 @@ func (c Config) Validate() error { return fmt.Errorf("jwt_claim_mappings[%d].path: %w", idx, err) } } + if err := validateNoOverlappingPaths(c.JWTClaimMappings); err != nil { + return err + } + return nil +} + +func validateNoOverlappingPaths(mappings []JWTClaimMapping) error { + for i, a := range mappings { + aParts := strings.Split(a.Path, ".") + for j, b := range mappings { + if i == j { + continue + } + bParts := strings.Split(b.Path, ".") + if isPathPrefix(aParts, bParts) { + return fmt.Errorf("%w: %q is a prefix of %q", ErrOverlappingAuditPaths, a.Path, b.Path) + } + } + } return nil } + +func isPathPrefix(short, long []string) bool { + if len(short) >= len(long) { + return false + } + for i, s := range short { + if s != long[i] { + return false + } + } + return true +} diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go index 91f417d233..59f2b8a5e2 100644 --- a/service/logger/audit/schema.go +++ b/service/logger/audit/schema.go @@ -14,6 +14,8 @@ var ( ErrUnknownAuditPath = errors.New("unknown audit path") // ErrAuditContainerPath indicates a claim destination resolves to a container instead of a writable leaf. ErrAuditContainerPath = errors.New("audit path resolves to a container") + // ErrOverlappingAuditPaths indicates two claim destinations conflict because one is a prefix of the other. + ErrOverlappingAuditPaths = errors.New("overlapping audit paths") auditClaimDestinationSchema = mustBuildAuditPathSchema(reflect.TypeOf(EventObject{})) ) diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go index 10229e948d..a31f3f6036 100644 --- a/service/logger/audit/schema_test.go +++ b/service/logger/audit/schema_test.go @@ -59,6 +59,40 @@ func TestValidateClaimDestinationPath(t *testing.T) { }) } +func TestValidateNoOverlappingPaths(t *testing.T) { + t.Run("allows sibling paths", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "sub", Path: "banana.kiwi"}, + {Claim: "email", Path: "banana.mango"}, + }) + require.NoError(t, err) + }) + + t.Run("rejects short prefix of long", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "sub", Path: "banana"}, + {Claim: "email", Path: "banana.kiwi.mango"}, + }) + require.ErrorIs(t, err, ErrOverlappingAuditPaths) + }) + + t.Run("rejects long prefix of short", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "email", Path: "banana.kiwi.mango"}, + {Claim: "sub", Path: "banana"}, + }) + require.ErrorIs(t, err, ErrOverlappingAuditPaths) + }) + + t.Run("allows identical depth different leaves", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + {Claim: "email", Path: "eventMetaData.requester.email"}, + }) + require.NoError(t, err) + }) +} + func TestBuildAuditPathSchemaRejectsUnknownTags(t *testing.T) { type badStruct struct { Field string `json:"field" audit:"resreved"` From f143cc235bd60fadbc629f72094f5b3863382b90 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Mon, 4 May 2026 08:29:00 -0700 Subject: [PATCH 18/20] refactor(audit): remove ok+err return from parseAuditFieldOptions Return (auditFieldOptions, error) instead of (*auditFieldOptions, error) to avoid nil-nil returns flagged by nilnil linter. A zero-value opts (empty name) signals "skip this field." Signed-off-by: jakedoublev --- service/logger/audit/enrichment.go | 4 ++-- service/logger/audit/schema.go | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/service/logger/audit/enrichment.go b/service/logger/audit/enrichment.go index 7502e0349f..ee53785050 100644 --- a/service/logger/audit/enrichment.go +++ b/service/logger/audit/enrichment.go @@ -118,8 +118,8 @@ func normalizeAuditValue(value any) any { if !field.IsExported() { continue } - opts, ok, _ := parseAuditFieldOptions(field) - if !ok { + opts, _ := parseAuditFieldOptions(field) + if opts.name == "" { continue } normalized[opts.name] = normalizeAuditValue(rv.Field(i).Interface()) diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go index 59f2b8a5e2..8a04995eaa 100644 --- a/service/logger/audit/schema.go +++ b/service/logger/audit/schema.go @@ -64,11 +64,11 @@ func addAuditSchemaFields(parent *auditPathSchema, structType reflect.Type) erro continue } - opts, ok, err := parseAuditFieldOptions(field) + opts, err := parseAuditFieldOptions(field) if err != nil { return err } - if !ok { + if opts.name == "" { continue } if _, exists := parent.children[opts.name]; exists { @@ -97,26 +97,26 @@ func addAuditSchemaFields(parent *auditPathSchema, structType reflect.Type) erro return nil } -func parseAuditFieldOptions(field reflect.StructField) (auditFieldOptions, bool, error) { +func parseAuditFieldOptions(field reflect.StructField) (auditFieldOptions, error) { tag := field.Tag.Get("audit") if tag == "-" { - return auditFieldOptions{}, false, nil + return auditFieldOptions{}, nil } reserved, extensible, err := parseAuditTag(tag) if err != nil { - return auditFieldOptions{}, false, fmt.Errorf("field %s: %w", field.Name, err) + return auditFieldOptions{}, fmt.Errorf("field %s: %w", field.Name, err) } name := parseJSONFieldName(field) if name == "" { - return auditFieldOptions{}, false, nil + return auditFieldOptions{}, nil } return auditFieldOptions{ name: name, reserved: reserved, extensible: extensible, - }, true, nil + }, nil } func parseAuditTag(tag string) (bool, bool, error) { From e924b91cccfdc4d2c419ef6564fe20894235b437 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Mon, 4 May 2026 10:15:47 -0700 Subject: [PATCH 19/20] fix(audit): validate full dot-path syntax before extensible fallback Hoist empty-segment validation before the schema traversal loop so malformed paths like "banana..mango" or "banana." are rejected even when the first segment triggers an extensible early return. Also remove redundant md.Get("Authorization") fallback in IPC token extraction since gRPC metadata keys are case-insensitive. Signed-off-by: jakedoublev --- service/internal/auth/authn.go | 3 --- service/logger/audit/schema.go | 6 ++++-- service/logger/audit/schema_test.go | 22 ++++++++++++++++++++-- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index a81ac4be87..449e252689 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -474,9 +474,6 @@ func rawAccessTokenFromIncomingMetadata(md metadata.MD) string { } authHeaders := md.Get("authorization") - if len(authHeaders) == 0 { - authHeaders = md.Get("Authorization") - } if len(authHeaders) == 0 || authHeaders[0] == "" { return "" } diff --git a/service/logger/audit/schema.go b/service/logger/audit/schema.go index 8a04995eaa..7dee297e7d 100644 --- a/service/logger/audit/schema.go +++ b/service/logger/audit/schema.go @@ -166,12 +166,14 @@ func validateClaimDestinationPath(path string) error { } segments := strings.Split(path, ".") - current := auditClaimDestinationSchema - for idx, segment := range segments { + for _, segment := range segments { if segment == "" { return fmt.Errorf("%w: %s", ErrUnknownAuditPath, path) } + } + current := auditClaimDestinationSchema + for idx, segment := range segments { child, ok := current.children[segment] if !ok { if current.extensible { diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go index a31f3f6036..8f378e2670 100644 --- a/service/logger/audit/schema_test.go +++ b/service/logger/audit/schema_test.go @@ -53,8 +53,26 @@ func TestValidateClaimDestinationPath(t *testing.T) { require.ErrorIs(t, err, ErrUnknownAuditPath) }) - t.Run("rejects leading dot paths", func(t *testing.T) { - err := validateClaimDestinationPath(".banana") + t.Run("rejects malformed dot paths", func(t *testing.T) { + for _, path := range []string{ + ".banana", // leading dot + "banana.", // trailing dot + "banana..mango", // consecutive dots + ".", // single dot + "..", // double dot + "a.b..c.d", // mid-path empty segment + ".a.b", // leading dot with valid tail + "eventMetaData.", // trailing dot after known node + } { + t.Run(path, func(t *testing.T) { + err := validateClaimDestinationPath(path) + require.ErrorIs(t, err, ErrUnknownAuditPath) + }) + } + }) + + t.Run("rejects empty path", func(t *testing.T) { + err := validateClaimDestinationPath("") require.ErrorIs(t, err, ErrUnknownAuditPath) }) } From 520fbcaba8bbfeef9def4897570ce8bbe6f81286 Mon Sep 17 00:00:00 2001 From: jakedoublev Date: Mon, 4 May 2026 10:46:24 -0700 Subject: [PATCH 20/20] fix(audit): reject duplicate destination paths in claim mappings Two claims mapped to the same path would silently overwrite each other at runtime. Detect this in validateNoOverlappingPaths and reject at config time. Also refactored pair iteration to check each pair once. Signed-off-by: jakedoublev --- service/logger/audit/config.go | 8 +++++++- service/logger/audit/schema_test.go | 8 ++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/service/logger/audit/config.go b/service/logger/audit/config.go index 04abd6562c..220fa1f971 100644 --- a/service/logger/audit/config.go +++ b/service/logger/audit/config.go @@ -42,13 +42,19 @@ func validateNoOverlappingPaths(mappings []JWTClaimMapping) error { for i, a := range mappings { aParts := strings.Split(a.Path, ".") for j, b := range mappings { - if i == j { + if i >= j { continue } + if a.Path == b.Path { + return fmt.Errorf("%w: duplicate path %q", ErrOverlappingAuditPaths, a.Path) + } bParts := strings.Split(b.Path, ".") if isPathPrefix(aParts, bParts) { return fmt.Errorf("%w: %q is a prefix of %q", ErrOverlappingAuditPaths, a.Path, b.Path) } + if isPathPrefix(bParts, aParts) { + return fmt.Errorf("%w: %q is a prefix of %q", ErrOverlappingAuditPaths, b.Path, a.Path) + } } } return nil diff --git a/service/logger/audit/schema_test.go b/service/logger/audit/schema_test.go index 8f378e2670..06bded902a 100644 --- a/service/logger/audit/schema_test.go +++ b/service/logger/audit/schema_test.go @@ -109,6 +109,14 @@ func TestValidateNoOverlappingPaths(t *testing.T) { }) require.NoError(t, err) }) + + t.Run("rejects duplicate destination paths", func(t *testing.T) { + err := validateNoOverlappingPaths([]JWTClaimMapping{ + {Claim: "sub", Path: "eventMetaData.requester.sub"}, + {Claim: "email", Path: "eventMetaData.requester.sub"}, + }) + require.ErrorIs(t, err, ErrOverlappingAuditPaths) + }) } func TestBuildAuditPathSchemaRejectsUnknownTags(t *testing.T) {