diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 8d16b784f7..7f32fc4f1a 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -19,6 +19,10 @@ import ( "connectrpc.com/validate" "github.com/go-chi/cors" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + attrconnect "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" + nsconnect "github.com/opentdf/platform/protocol/go/policy/namespaces/namespacesconnect" + smconnect "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" + unsafeconnect "github.com/opentdf/platform/protocol/go/policy/unsafe/unsafeconnect" "github.com/opentdf/platform/sdk" sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/internal/auth" @@ -28,6 +32,7 @@ import ( "github.com/opentdf/platform/service/logger/audit" ctxAuth "github.com/opentdf/platform/service/pkg/auth" "github.com/opentdf/platform/service/pkg/cache" + "github.com/opentdf/platform/service/pkg/enumnormalize" "github.com/opentdf/platform/service/tracing" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -62,8 +67,9 @@ type Config struct { WellKnownConfigRegister func(namespace string, config any) error `mapstructure:"-" json:"-"` // Programmatic interceptors injected at startup (not loaded from config) - ExtraConnectInterceptors []connect.Interceptor `mapstructure:"-" json:"-"` - ExtraIPCInterceptors []connect.Interceptor `mapstructure:"-" json:"-"` + ExtraConnectInterceptors []connect.Interceptor `mapstructure:"-" json:"-"` + ExtraIPCInterceptors []connect.Interceptor `mapstructure:"-" json:"-"` + ExtraHTTPMiddleware []func(http.Handler) http.Handler `mapstructure:"-" json:"-"` // Port to listen on Port int `mapstructure:"port" json:"port" default:"8080"` Host string `mapstructure:"host,omitempty" json:"host"` @@ -392,6 +398,42 @@ func newHTTPServer(c Config, connectRPC http.Handler, originalGrpcGateway http.H tc *tls.Config ) + // Normalize shorthand enum names (e.g. "IN" → "SUBJECT_MAPPING_OPERATOR_ENUM_IN") + // in JSON request bodies before ConnectRPC deserializes them. Accepts the + // suffix after the enum type prefix, case-insensitive, while full canonical + // names continue to work unchanged. See: opentdf/platform#3338 + connectRPC = enumnormalize.NewMiddleware( + []enumnormalize.EnumFieldRule{ + // Subject Mapping enums + {JSONField: "operator", Prefix: "SUBJECT_MAPPING_OPERATOR_ENUM_"}, + {JSONField: "booleanOperator", Prefix: "CONDITION_BOOLEAN_TYPE_ENUM_"}, + // Attribute rule type + {JSONField: "rule", Prefix: "ATTRIBUTE_RULE_TYPE_ENUM_"}, + // Active state filter (list requests) + {JSONField: "state", Prefix: "ACTIVE_STATE_ENUM_"}, + }, + []string{ + // Subject Mapping RPCs + smconnect.SubjectMappingServiceCreateSubjectMappingProcedure, + smconnect.SubjectMappingServiceCreateSubjectConditionSetProcedure, + smconnect.SubjectMappingServiceUpdateSubjectConditionSetProcedure, + // Attribute RPCs (rule + state) + attrconnect.AttributesServiceCreateAttributeProcedure, + attrconnect.AttributesServiceUpdateAttributeProcedure, + attrconnect.AttributesServiceListAttributesProcedure, + attrconnect.AttributesServiceListAttributeValuesProcedure, + // Namespace RPCs (state) + nsconnect.NamespaceServiceListNamespacesProcedure, + // Unsafe RPCs (rule) + unsafeconnect.UnsafeServiceUnsafeUpdateAttributeProcedure, + }, + )(connectRPC) + + // Apply extra HTTP middleware injected by downstream consumers (e.g. DSP). + for _, mw := range c.ExtraHTTPMiddleware { + connectRPC = mw(connectRPC) + } + // Adds deprecation header to any grpcGateway responses. var grpcGateway http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { grpcRW := &grpcGatewayResponseWriter{w: w, code: http.StatusOK} diff --git a/service/pkg/enumnormalize/enumnormalize.go b/service/pkg/enumnormalize/enumnormalize.go new file mode 100644 index 0000000000..25f133475a --- /dev/null +++ b/service/pkg/enumnormalize/enumnormalize.go @@ -0,0 +1,129 @@ +package enumnormalize + +import ( + "bytes" + "encoding/json" + "io" + "strings" +) + +// EnumFieldRule maps a JSON field name to the prefix that protobuf requires. +// When the middleware encounters a string value in a matching field that does +// not already carry the prefix, it prepends the prefix so that protojson +// recognises the canonical enum name. +type EnumFieldRule struct { + // JSONField is the protojson camelCase field name (e.g. "operator", "booleanOperator"). + JSONField string + // Prefix is the proto enum type prefix including trailing underscore + // (e.g. "SUBJECT_MAPPING_OPERATOR_ENUM_"). + Prefix string + // ParentField optionally scopes this rule to only match when JSONField + // appears inside an object that is a direct child of a key named + // ParentField (at any depth). This disambiguates cases where multiple + // enum types share the same field name (e.g. "type") but live under + // different parent keys (e.g. "contentExtractors" vs "tagProcessors"). + // When empty, the rule matches JSONField at any position (original behavior). + ParentField string +} + +// ruleLookup stores pre-built lookup tables for fast matching. +type ruleLookup struct { + // global maps field name → prefix for rules with no ParentField. + global map[string]string + // scoped maps parentField → (field name → prefix) for parent-scoped rules. + scoped map[string]map[string]string +} + +// buildRuleLookup creates a ruleLookup from a set of rules. +func buildRuleLookup(rules []EnumFieldRule) ruleLookup { + rl := ruleLookup{ + global: make(map[string]string), + scoped: make(map[string]map[string]string), + } + for _, r := range rules { + if r.ParentField == "" { + rl.global[r.JSONField] = r.Prefix + } else { + if rl.scoped[r.ParentField] == nil { + rl.scoped[r.ParentField] = make(map[string]string) + } + rl.scoped[r.ParentField][r.JSONField] = r.Prefix + } + } + return rl +} + +// normalizeJSON rewrites shorthand enum string values in body according to +// the configured rules. Values that already carry the full prefix, numeric +// values, and fields not covered by any rule pass through unchanged. +func normalizeJSON(body []byte, rl ruleLookup) ([]byte, error) { + if len(body) == 0 || (len(rl.global) == 0 && len(rl.scoped) == 0) { + return body, nil + } + + // Use json.Decoder with UseNumber to preserve numeric precision + // (avoids float64 conversion of large int64 values). + decoder := json.NewDecoder(bytes.NewReader(body)) + decoder.UseNumber() + + var parsed any + if err := decoder.Decode(&parsed); err != nil { + // Not valid JSON — pass through and let ConnectRPC surface the error. + return body, nil //nolint:nilerr // intentional: invalid JSON is not our error to report + } + + // Ensure the entire body is a single JSON value. If there are trailing + // tokens (e.g. `{"a":1}{"b":2}`), return the original body so ConnectRPC + // can reject the malformed input rather than silently dropping the tail. + var trailing any + if err := decoder.Decode(&trailing); err != io.EOF { + return body, nil + } + + normalizeValue(parsed, rl, "") + + return json.Marshal(parsed) +} + +// normalizeValue recursively walks a decoded JSON value, normalizing string +// enum fields according to the lookup rules. parentKey tracks the key under +// which the current value was found, enabling parent-scoped rules. +func normalizeValue(v any, rl ruleLookup, parentKey string) { + switch val := v.(type) { + case map[string]any: + for key, child := range val { + // Check global rules (no parent scope) + if prefix, ok := rl.global[key]; ok { + if s, isStr := child.(string); isStr { + val[key] = applyPrefix(s, prefix) + } + } + // Check parent-scoped rules + if scopedFields, hasParent := rl.scoped[parentKey]; hasParent { + if scopedPrefix, hasField := scopedFields[key]; hasField { + if s, isStr := child.(string); isStr { + val[key] = applyPrefix(s, scopedPrefix) + } + } + } + normalizeValue(child, rl, key) + } + case []any: + // Array elements inherit the parent key so that scoped rules work + // through arrays (e.g. "contentExtractors": [{"type": "..."}]). + for _, item := range val { + normalizeValue(item, rl, parentKey) + } + } +} + +// applyPrefix prepends prefix to value if it is not already present +// (case-insensitive check). The value is upper-cased before comparison and +// before prepending so that "in" and "IN" both resolve correctly. +func applyPrefix(value, prefix string) string { + upper := strings.ToUpper(value) + if strings.HasPrefix(upper, strings.ToUpper(prefix)) { + return upper + } + return prefix + upper +} diff --git a/service/pkg/enumnormalize/enumnormalize_test.go b/service/pkg/enumnormalize/enumnormalize_test.go new file mode 100644 index 0000000000..4d13ca384f --- /dev/null +++ b/service/pkg/enumnormalize/enumnormalize_test.go @@ -0,0 +1,480 @@ +package enumnormalize + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var allLookup = buildRuleLookup([]EnumFieldRule{ + {JSONField: "operator", Prefix: "SUBJECT_MAPPING_OPERATOR_ENUM_"}, + {JSONField: "booleanOperator", Prefix: "CONDITION_BOOLEAN_TYPE_ENUM_"}, + {JSONField: "rule", Prefix: "ATTRIBUTE_RULE_TYPE_ENUM_"}, + {JSONField: "state", Prefix: "ACTIVE_STATE_ENUM_"}, +}) + +func TestNormalizeJSON_ShorthandOperators(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "IN shorthand", + input: `{"operator":"IN"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}`, + }, + { + name: "NOT_IN shorthand", + input: `{"operator":"NOT_IN"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN"}`, + }, + { + name: "IN_CONTAINS shorthand", + input: `{"operator":"IN_CONTAINS"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN_CONTAINS"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := normalizeJSON([]byte(tt.input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_ShorthandBooleanOperators(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "AND shorthand", + input: `{"booleanOperator":"AND"}`, + expected: `{"booleanOperator":"CONDITION_BOOLEAN_TYPE_ENUM_AND"}`, + }, + { + name: "OR shorthand", + input: `{"booleanOperator":"OR"}`, + expected: `{"booleanOperator":"CONDITION_BOOLEAN_TYPE_ENUM_OR"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := normalizeJSON([]byte(tt.input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_ShorthandAttributeRuleType(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "ALL_OF shorthand", + input: `{"rule":"ALL_OF"}`, + expected: `{"rule":"ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF"}`, + }, + { + name: "ANY_OF shorthand", + input: `{"rule":"ANY_OF"}`, + expected: `{"rule":"ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF"}`, + }, + { + name: "HIERARCHY shorthand", + input: `{"rule":"HIERARCHY"}`, + expected: `{"rule":"ATTRIBUTE_RULE_TYPE_ENUM_HIERARCHY"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := normalizeJSON([]byte(tt.input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_ShorthandActiveState(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "ACTIVE shorthand", + input: `{"state":"ACTIVE"}`, + expected: `{"state":"ACTIVE_STATE_ENUM_ACTIVE"}`, + }, + { + name: "INACTIVE shorthand", + input: `{"state":"INACTIVE"}`, + expected: `{"state":"ACTIVE_STATE_ENUM_INACTIVE"}`, + }, + { + name: "ANY shorthand", + input: `{"state":"ANY"}`, + expected: `{"state":"ACTIVE_STATE_ENUM_ANY"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := normalizeJSON([]byte(tt.input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_CaseInsensitive(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "lowercase in", + input: `{"operator":"in"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}`, + }, + { + name: "lowercase and", + input: `{"booleanOperator":"and"}`, + expected: `{"booleanOperator":"CONDITION_BOOLEAN_TYPE_ENUM_AND"}`, + }, + { + name: "mixed case Not_In", + input: `{"operator":"Not_In"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := normalizeJSON([]byte(tt.input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_FullCanonicalNamesPassThrough(t *testing.T) { + tests := []struct { + name string + input string + }{ + { + name: "full operator name", + input: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}`, + }, + { + name: "full boolean operator name", + input: `{"booleanOperator":"CONDITION_BOOLEAN_TYPE_ENUM_AND"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := normalizeJSON([]byte(tt.input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, tt.input, string(out)) + }) + } +} + +func TestNormalizeJSON_NumericValuesPassThrough(t *testing.T) { + tests := []struct { + name string + input string + }{ + { + name: "operator 1 (IN) and booleanOperator 2 (OR)", + input: `{"operator":1,"booleanOperator":2}`, + }, + { + name: "operator 3 (IN_CONTAINS) and booleanOperator 1 (AND)", + input: `{"operator":3,"booleanOperator":1}`, + }, + { + name: "operator 2 (NOT_IN)", + input: `{"operator":2}`, + }, + { + name: "rule 1 (ALL_OF)", + input: `{"rule":1}`, + }, + { + name: "state 1 (ACTIVE)", + input: `{"state":1}`, + }, + { + name: "numeric zero (UNSPECIFIED) passes through", + input: `{"operator":0}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := normalizeJSON([]byte(tt.input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, tt.input, string(out)) + }) + } +} + +func TestNormalizeJSON_NumericValuesInNestedStructure(t *testing.T) { + // Simulates the JSON format that was previously used in documentation: + // numeric enum codes instead of string names. + input := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": 1, + "conditions": [ + { + "subjectExternalSelectorValue": ".email", + "operator": 3, + "subjectExternalValues": ["@example.com"] + }, + { + "subjectExternalSelectorValue": ".role", + "operator": 1, + "subjectExternalValues": ["admin"] + } + ] + }] + }] + } + }` + + out, err := normalizeJSON([]byte(input), allLookup) + require.NoError(t, err) + // Numeric values should pass through unchanged — protojson natively + // accepts numeric enum representations. + assert.JSONEq(t, input, string(out)) +} + +func TestNormalizeJSON_UnknownValuesGetPrefixed(t *testing.T) { + // Unknown shorthand values get the prefix prepended; downstream + // protovalidate will reject them. + input := `{"operator":"FOOBAR"}` + expected := `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_FOOBAR"}` + out, err := normalizeJSON([]byte(input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_UnrelatedFieldsUntouched(t *testing.T) { + input := `{"name":"test","description":"IN","operator":"IN"}` + out, err := normalizeJSON([]byte(input), allLookup) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(out, &result)) + + // "description" should NOT be prefixed — only "operator" is a rule field + assert.Equal(t, "IN", result["description"]) + assert.Equal(t, "SUBJECT_MAPPING_OPERATOR_ENUM_IN", result["operator"]) +} + +func TestNormalizeJSON_DeeplyNestedStructure(t *testing.T) { + // Simulates a CreateSubjectConditionSetRequest with nested condition groups + input := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": "AND", + "conditions": [ + { + "subjectExternalSelectorValue": ".email", + "operator": "IN", + "subjectExternalValues": ["user@example.com"] + }, + { + "subjectExternalSelectorValue": ".groups", + "operator": "NOT_IN", + "subjectExternalValues": ["banned"] + } + ] + }] + }] + } + }` + + expected := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": "CONDITION_BOOLEAN_TYPE_ENUM_AND", + "conditions": [ + { + "subjectExternalSelectorValue": ".email", + "operator": "SUBJECT_MAPPING_OPERATOR_ENUM_IN", + "subjectExternalValues": ["user@example.com"] + }, + { + "subjectExternalSelectorValue": ".groups", + "operator": "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN", + "subjectExternalValues": ["banned"] + } + ] + }] + }] + } + }` + + out, err := normalizeJSON([]byte(input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_MixedShorthandAndFullNames(t *testing.T) { + input := `{ + "conditionGroups": [{ + "booleanOperator": "CONDITION_BOOLEAN_TYPE_ENUM_OR", + "conditions": [ + {"operator": "IN"}, + {"operator": "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN"} + ] + }] + }` + + expected := `{ + "conditionGroups": [{ + "booleanOperator": "CONDITION_BOOLEAN_TYPE_ENUM_OR", + "conditions": [ + {"operator": "SUBJECT_MAPPING_OPERATOR_ENUM_IN"}, + {"operator": "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN"} + ] + }] + }` + + out, err := normalizeJSON([]byte(input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_EmptyBody(t *testing.T) { + out, err := normalizeJSON([]byte{}, allLookup) + require.NoError(t, err) + assert.Empty(t, out) +} + +func TestNormalizeJSON_NoRules(t *testing.T) { + input := `{"operator":"IN"}` + out, err := normalizeJSON([]byte(input), ruleLookup{}) + require.NoError(t, err) + assert.Equal(t, input, string(out)) +} + +func TestNormalizeJSON_TrailingJSONTokensPassThrough(t *testing.T) { + // Multiple concatenated JSON values should pass through unchanged so + // ConnectRPC can reject the malformed input. + input := `{"operator":"IN"}{"extra":1}` + out, err := normalizeJSON([]byte(input), allLookup) + require.NoError(t, err) + assert.Equal(t, input, string(out)) +} + +func TestNormalizeJSON_InvalidJSON(t *testing.T) { + input := `not json at all` + out, err := normalizeJSON([]byte(input), allLookup) + require.NoError(t, err) + // Invalid JSON passes through unchanged + assert.Equal(t, input, string(out)) +} + +// Parent-scoped rule tests + +var scopedLookup = buildRuleLookup([]EnumFieldRule{ + // Different prefixes for the same "type" field, scoped by parent key + {JSONField: "type", Prefix: "CONTENT_EXTRACTOR_TYPE_", ParentField: "contentExtractors"}, + {JSONField: "type", Prefix: "TAG_PROCESSOR_TYPE_", ParentField: "tagProcessors"}, + // A global rule (no parent scope) for a different field + {JSONField: "state", Prefix: "ACTIVE_STATE_ENUM_"}, +}) + +func TestNormalizeJSON_ParentScopedRules(t *testing.T) { + input := `{ + "config": { + "v1": { + "contentExtractors": [{"type": "TIKA_CONTENT_EXTRACTION", "id": "ce1"}], + "tagProcessors": [{"type": "REQUIRED_TAGS", "id": "tp1"}] + } + } + }` + expected := `{ + "config": { + "v1": { + "contentExtractors": [{"type": "CONTENT_EXTRACTOR_TYPE_TIKA_CONTENT_EXTRACTION", "id": "ce1"}], + "tagProcessors": [{"type": "TAG_PROCESSOR_TYPE_REQUIRED_TAGS", "id": "tp1"}] + } + } + }` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_ParentScopedDoesNotMatchGlobally(t *testing.T) { + // "type" at top level should NOT be rewritten — it only matches under + // "contentExtractors" or "tagProcessors". + input := `{"type": "SOME_VALUE"}` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, input, string(out)) +} + +func TestNormalizeJSON_GlobalAndScopedRulesCoexist(t *testing.T) { + // "state" is a global rule; "type" is parent-scoped. + input := `{ + "state": "ACTIVE", + "contentExtractors": [{"type": "TIKA_CONTENT_EXTRACTION"}] + }` + expected := `{ + "state": "ACTIVE_STATE_ENUM_ACTIVE", + "contentExtractors": [{"type": "CONTENT_EXTRACTOR_TYPE_TIKA_CONTENT_EXTRACTION"}] + }` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_ParentScopedFullCanonicalPassthrough(t *testing.T) { + // Already-prefixed values pass through unchanged + input := `{ + "contentExtractors": [{"type": "CONTENT_EXTRACTOR_TYPE_TIKA_CONTENT_EXTRACTION"}] + }` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, input, string(out)) +} + +func TestNormalizeJSON_ParentScopedCaseInsensitive(t *testing.T) { + input := `{ + "tagProcessors": [{"type": "required_tags"}] + }` + expected := `{ + "tagProcessors": [{"type": "TAG_PROCESSOR_TYPE_REQUIRED_TAGS"}] + }` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} diff --git a/service/pkg/enumnormalize/middleware.go b/service/pkg/enumnormalize/middleware.go new file mode 100644 index 0000000000..a639192c1b --- /dev/null +++ b/service/pkg/enumnormalize/middleware.go @@ -0,0 +1,73 @@ +package enumnormalize + +import ( + "bytes" + "io" + "net/http" + "strconv" + "strings" +) + +// defaultMaxBodySize is the fallback upper bound on request bodies the +// middleware will read into memory for normalization when no explicit limit is +// provided. Policy API request bodies are small (typically under 10 KB). +const defaultMaxBodySize = 1 << 20 // 1 MB + +// NewMiddleware returns HTTP middleware that normalises shorthand enum string +// values in JSON request bodies for the given RPC paths. Requests that do not +// match (wrong content-type, wrong path) are forwarded unchanged with zero +// overhead. An optional maxBodyBytes sets the upper bound on request body size; +// defaults to 1 MB if omitted or zero. +func NewMiddleware(rules []EnumFieldRule, paths []string, maxBodyBytes ...int64) func(http.Handler) http.Handler { + bodyLimit := int64(defaultMaxBodySize) + if len(maxBodyBytes) > 0 && maxBodyBytes[0] > 0 { + bodyLimit = maxBodyBytes[0] + } + lookup := buildRuleLookup(rules) + + pathSet := make(map[string]struct{}, len(paths)) + for _, p := range paths { + pathSet[p] = struct{}{} + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only rewrite JSON bodies on matching RPC paths. + if !isJSON(r) || !matchesPath(r, pathSet) { + next.ServeHTTP(w, r) + return + } + + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, bodyLimit)) + if err != nil { + next.ServeHTTP(w, r) + return + } + + normalized, err := normalizeJSON(body, lookup) + if err != nil { + // On normalisation failure, send the original body so + // ConnectRPC can surface its own error. + normalized = body + } + + r.Body = io.NopCloser(bytes.NewReader(normalized)) + r.ContentLength = int64(len(normalized)) + r.Header.Set("Content-Length", strconv.Itoa(len(normalized))) + + next.ServeHTTP(w, r) + }) + } +} + +// isJSON returns true when the request Content-Type indicates a JSON payload +// (application/json or application/connect+json). +func isJSON(r *http.Request) bool { + return strings.Contains(r.Header.Get("Content-Type"), "json") +} + +// matchesPath returns true when the request URL path is in pathSet. +func matchesPath(r *http.Request, pathSet map[string]struct{}) bool { + _, ok := pathSet[r.URL.Path] + return ok +} diff --git a/service/pkg/enumnormalize/middleware_test.go b/service/pkg/enumnormalize/middleware_test.go new file mode 100644 index 0000000000..5f26fe1b0b --- /dev/null +++ b/service/pkg/enumnormalize/middleware_test.go @@ -0,0 +1,142 @@ +package enumnormalize + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testPath = "/policy.subjectmapping.SubjectMappingService/CreateSubjectConditionSet" + +var testRules = []EnumFieldRule{ + {JSONField: "operator", Prefix: "SUBJECT_MAPPING_OPERATOR_ENUM_"}, + {JSONField: "booleanOperator", Prefix: "CONDITION_BOOLEAN_TYPE_ENUM_"}, +} + +// captureHandler records the request body it receives. +type captureHandler struct { + body string +} + +func (h *captureHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + h.body = string(b) +} + +func TestMiddleware_NormalizesMatchingJSONRequest(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"booleanOperator":"AND","conditions":[{"operator":"IN"}]}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.Contains(t, capture.body, "CONDITION_BOOLEAN_TYPE_ENUM_AND") + assert.Contains(t, capture.body, "SUBJECT_MAPPING_OPERATOR_ENUM_IN") +} + +func TestMiddleware_ConnectJSONContentType(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"NOT_IN"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/connect+json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.Contains(t, capture.body, "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN") +} + +func TestMiddleware_NonMatchingPathPassesThrough(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"IN"}` + req := httptest.NewRequest(http.MethodPost, "/policy.attributes.AttributesService/ListAttributes", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + // Should be the original body, not normalized + assert.Equal(t, body, capture.body) +} + +func TestMiddleware_NonJSONContentTypePassesThrough(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"IN"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/proto") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.Equal(t, body, capture.body) +} + +func TestMiddleware_CanonicalNamesUnchanged(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.JSONEq(t, body, capture.body) +} + +func TestMiddleware_NumericEnumValuesPassThrough(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + // Numeric enum values (e.g., 1 for IN, 3 for IN_CONTAINS) are valid + // protojson and should pass through the middleware unchanged. + body := `{"booleanOperator":1,"conditions":[{"operator":3}]}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.JSONEq(t, body, capture.body) +} + +func TestMiddleware_ContentLengthUpdated(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"IN"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + // The normalized body is longer than the original + require.Greater(t, len(capture.body), len(body)) +} + +func TestMiddleware_OversizedBodySkipsNormalization(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + // Build a body that exceeds the default max body size (1 MB). + oversized := `{"operator":"` + strings.Repeat("A", defaultMaxBodySize) + `"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(oversized)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + // The middleware should skip normalization on read error and forward the + // request. The downstream handler receives whatever MaxBytesReader yielded + // before the limit — NOT a normalized body. + assert.NotContains(t, capture.body, "SUBJECT_MAPPING_OPERATOR_ENUM_") +} diff --git a/service/pkg/server/options.go b/service/pkg/server/options.go index db3952ceef..9722f5712f 100644 --- a/service/pkg/server/options.go +++ b/service/pkg/server/options.go @@ -2,6 +2,7 @@ package server import ( "context" + "net/http" "connectrpc.com/connect" "github.com/casbin/casbin/v2/persist" @@ -28,6 +29,7 @@ type StartConfig struct { extraConnectInterceptors []connect.Interceptor extraIPCInterceptors []connect.Interceptor + extraHTTPMiddleware []func(http.Handler) http.Handler trustKeyManagerCtxs []trust.NamedKeyManagerCtxFactory @@ -186,6 +188,25 @@ func WithIPCInterceptors(interceptors ...connect.Interceptor) StartOptions { } } +// WithHTTPMiddleware appends HTTP middleware that wraps the ConnectRPC handler. +// Middleware is applied in order, with the last middleware outermost. +// This runs at the HTTP transport layer, before ConnectRPC deserialization, +// making it suitable for request body rewriting (e.g. enum normalization). +// +// Example: +// +// server.Start( +// server.WithHTTPMiddleware( +// enumnormalize.NewMiddleware(rules, paths), +// ), +// ) +func WithHTTPMiddleware(middleware ...func(http.Handler) http.Handler) StartOptions { + return func(c StartConfig) StartConfig { + c.extraHTTPMiddleware = append(c.extraHTTPMiddleware, middleware...) + return c + } +} + // WithTrustKeyManagerFactories option provides factories for creating trust key managers. // Use WithTrustKeyManagerCtxFactories instead. // EXPERIMENTAL diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index f3513a090d..85ea87a69f 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -154,6 +154,7 @@ func Start(f ...StartOptions) error { // Programmatic Connect/IPC interceptors (not config-driven) cfg.Server.ExtraConnectInterceptors = append(cfg.Server.ExtraConnectInterceptors, startConfig.extraConnectInterceptors...) cfg.Server.ExtraIPCInterceptors = append(cfg.Server.ExtraIPCInterceptors, startConfig.extraIPCInterceptors...) + cfg.Server.ExtraHTTPMiddleware = append(cfg.Server.ExtraHTTPMiddleware, startConfig.extraHTTPMiddleware...) // Set Default Policy if startConfig.builtinPolicyOverride != "" { diff --git a/tests-bdd/cukes/steps_enum_shorthand.go b/tests-bdd/cukes/steps_enum_shorthand.go new file mode 100644 index 0000000000..5858ca2cf1 --- /dev/null +++ b/tests-bdd/cukes/steps_enum_shorthand.go @@ -0,0 +1,273 @@ +package cukes + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + "github.com/cucumber/godog" +) + +const bddHTTPTimeout = 15 * time.Second + +type EnumShorthandStepDefinitions struct{} + +// getAccessToken fetches a bearer token from the Keycloak token endpoint using +// the same client credentials the BDD test SDK uses. +func getAccessToken(ctx context.Context, tokenEndpoint string) (string, error) { + data := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"opentdf"}, + "client_secret": {"secret"}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return "", fmt.Errorf("token request creation failed: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{Timeout: bddHTTPTimeout} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("token request returned %d: %s", resp.StatusCode, body) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + } + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", fmt.Errorf("failed to decode token response: %w", err) + } + return tokenResp.AccessToken, nil +} + +// postConnectRPC sends a raw JSON body to a ConnectRPC endpoint and returns the +// HTTP status code and response body. +func postConnectRPC(ctx context.Context, endpoint, rpcPath, token, jsonBody string) (int, string, error) { + client := &http.Client{Timeout: bddHTTPTimeout} + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+rpcPath, strings.NewReader(jsonBody)) + if err != nil { + return 0, "", err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := client.Do(req) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, "", err + } + return resp.StatusCode, string(body), nil +} + +// prepareAuthenticatedRequest extracts the platform endpoint and fetches a +// bearer token for raw HTTP requests. This is the common setup shared by all +// shorthand enum e2e step definitions. +func prepareAuthenticatedRequest(ctx context.Context) (*PlatformScenarioContext, string, string, error) { + scenarioContext := GetPlatformScenarioContext(ctx) + scenarioContext.ClearError() + + endpoint := scenarioContext.ScenarioOptions.PlatformEndpoint + tokenEndpoint, err := scenarioContext.SDK.PlatformConfiguration.TokenEndpoint() + if err != nil { + return nil, "", "", fmt.Errorf("failed to get token endpoint: %w", err) + } + + token, err := getAccessToken(ctx, tokenEndpoint) + if err != nil { + return nil, "", "", fmt.Errorf("failed to get access token: %w", err) + } + + return scenarioContext, endpoint, token, nil +} + +// iCreateASubjectConditionSetViaHTTPWithShorthandEnums sends a raw HTTP POST with +// shorthand enum strings and verifies the platform accepts it. +func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithShorthandEnums(ctx context.Context) (context.Context, error) { + scenarioContext, endpoint, token, err := prepareAuthenticatedRequest(ctx) + if err != nil { + return ctx, err + } + + // Raw JSON with shorthand enum values — this is what the middleware normalizes. + body := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": "AND", + "conditions": [{ + "subjectExternalSelectorValue": ".email", + "operator": "IN_CONTAINS", + "subjectExternalValues": ["@example.com"] + }] + }] + }] + } + }` + + rpcPath := "/policy.subjectmapping.SubjectMappingService/CreateSubjectConditionSet" + statusCode, respBody, err := postConnectRPC(ctx, endpoint, rpcPath, token, body) + if err != nil { + return ctx, fmt.Errorf("HTTP request failed: %w", err) + } + + slog.Debug("shorthand enum e2e response", + slog.Int("status", statusCode), + slog.String("body", respBody)) + + if statusCode != http.StatusOK { + return ctx, fmt.Errorf("expected HTTP 200, got %d: %s", statusCode, respBody) + } + + // Verify the response contains a valid subject condition set ID + var result map[string]any + if err := json.Unmarshal([]byte(respBody), &result); err != nil { + return ctx, fmt.Errorf("failed to parse response: %w", err) + } + scs, ok := result["subjectConditionSet"].(map[string]any) + if !ok || scs["id"] == nil { + return ctx, fmt.Errorf("response missing subjectConditionSet.id: %s", respBody) + } + + scsID, ok := scs["id"].(string) + if !ok { + return ctx, fmt.Errorf("subjectConditionSet.id is not a string: %s", respBody) + } + scenarioContext.RecordObject("shorthand_scs_id", scsID) + return ctx, nil +} + +// iCreateAnAttributeViaHTTPWithShorthandRule sends a raw HTTP POST to create an +// attribute using a shorthand rule type enum. +func (s *EnumShorthandStepDefinitions) iCreateAnAttributeViaHTTPWithShorthandRule(ctx context.Context) (context.Context, error) { + scenarioContext, endpoint, token, err := prepareAuthenticatedRequest(ctx) + if err != nil { + return ctx, err + } + + // Get the namespace ID that was created by the scenario setup + nsID, ok := scenarioContext.GetObject("ns1").(string) + if !ok { + return ctx, errors.New("namespace ns1 not found in scenario context") + } + + // Raw JSON with shorthand rule type — fields are at the top level per the proto definition + body := fmt.Sprintf(`{ + "namespaceId": "%s", + "name": "shorthand_test_attr", + "rule": "ANY_OF", + "values": ["val1", "val2"] + }`, nsID) + + rpcPath := "/policy.attributes.AttributesService/CreateAttribute" + statusCode, respBody, err := postConnectRPC(ctx, endpoint, rpcPath, token, body) + if err != nil { + return ctx, fmt.Errorf("HTTP request failed: %w", err) + } + + slog.Debug("shorthand rule e2e response", + slog.Int("status", statusCode), + slog.String("body", respBody)) + + if statusCode != http.StatusOK { + return ctx, fmt.Errorf("expected HTTP 200, got %d: %s", statusCode, respBody) + } + + // Verify the response contains a valid attribute with the correct rule + var result map[string]any + if err := json.Unmarshal([]byte(respBody), &result); err != nil { + return ctx, fmt.Errorf("failed to parse response: %w", err) + } + attr, ok := result["attribute"].(map[string]any) + if !ok || attr["id"] == nil { + return ctx, fmt.Errorf("response missing attribute.id: %s", respBody) + } + + // Verify the rule was accepted and stored as the canonical name + rule, _ := attr["rule"].(string) + if rule != "ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF" { + return ctx, fmt.Errorf("expected rule ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF, got %s", rule) + } + + return ctx, nil +} + +// iCreateASubjectConditionSetViaHTTPWithMixedEnumFormats verifies that a request +// mixing shorthand and canonical enum names works correctly. +func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithMixedEnumFormats(ctx context.Context) (context.Context, error) { + _, endpoint, token, err := prepareAuthenticatedRequest(ctx) + if err != nil { + return ctx, err + } + + // Mix shorthand and canonical names in the same request + body := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": "CONDITION_BOOLEAN_TYPE_ENUM_AND", + "conditions": [ + { + "subjectExternalSelectorValue": ".email", + "operator": "IN", + "subjectExternalValues": ["@test.com"] + }, + { + "subjectExternalSelectorValue": ".role", + "operator": "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN", + "subjectExternalValues": ["guest"] + } + ] + }] + }] + } + }` + + rpcPath := "/policy.subjectmapping.SubjectMappingService/CreateSubjectConditionSet" + statusCode, respBody, err := postConnectRPC(ctx, endpoint, rpcPath, token, body) + if err != nil { + return ctx, fmt.Errorf("HTTP request failed: %w", err) + } + + if statusCode != http.StatusOK { + return ctx, fmt.Errorf("expected HTTP 200, got %d: %s", statusCode, respBody) + } + + var result map[string]any + if err := json.Unmarshal([]byte(respBody), &result); err != nil { + return ctx, fmt.Errorf("failed to parse response: %w", err) + } + scs, ok := result["subjectConditionSet"].(map[string]any) + if !ok || scs["id"] == nil { + return ctx, fmt.Errorf("response missing subjectConditionSet.id: %s", respBody) + } + + return ctx, nil +} + +func RegisterEnumShorthandStepDefinitions(ctx *godog.ScenarioContext) { + steps := &EnumShorthandStepDefinitions{} + ctx.Step(`^I create a subject condition set via HTTP with shorthand enums$`, steps.iCreateASubjectConditionSetViaHTTPWithShorthandEnums) + ctx.Step(`^I create an attribute via HTTP with shorthand rule type$`, steps.iCreateAnAttributeViaHTTPWithShorthandRule) + ctx.Step(`^I create a subject condition set via HTTP with mixed enum formats$`, steps.iCreateASubjectConditionSetViaHTTPWithMixedEnumFormats) +} diff --git a/tests-bdd/features/shorthand-enums.feature b/tests-bdd/features/shorthand-enums.feature new file mode 100644 index 0000000000..274667dec1 --- /dev/null +++ b/tests-bdd/features/shorthand-enums.feature @@ -0,0 +1,18 @@ +@shorthand-enums +Feature: Shorthand Enum Names E2E + Verify that the platform accepts shorthand enum names (e.g., "IN", "AND", + "ANY_OF") in raw HTTP JSON requests. These tests bypass the SDK and send + raw ConnectRPC JSON to prove the normalization middleware works end-to-end. + + Background: + Given an empty local platform + And I submit a request to create a namespace with name "shorthandenums.io" and reference id "ns1" + + Scenario: Create subject condition set with shorthand operator and boolean enums + When I create a subject condition set via HTTP with shorthand enums + + Scenario: Create attribute with shorthand rule type enum + When I create an attribute via HTTP with shorthand rule type + + Scenario: Create subject condition set with mixed shorthand and canonical enum names + When I create a subject condition set via HTTP with mixed enum formats diff --git a/tests-bdd/platform_test.go b/tests-bdd/platform_test.go index f1251c8d8f..6acafef0a1 100644 --- a/tests-bdd/platform_test.go +++ b/tests-bdd/platform_test.go @@ -108,6 +108,7 @@ func runTests() int { cukes.RegisterSmokeStepDefinitions(ctx, platformCukesContext) cukes.RegisterAuthorizationStepDefinitions(ctx) cukes.RegisterSubjectMappingsStepsDefinitions(ctx) + cukes.RegisterEnumShorthandStepDefinitions(ctx) cukes.RegisterRegisteredResourcesStepDefinitions(ctx) cukes.RegisterObligationsStepDefinitions(ctx, platformCukesContext) platformCukesContext.InitializeScenario(ctx)