diff --git a/docs/filters/http/ai/openai_responses_model_rewrite.md b/docs/filters/http/ai/openai_responses_model_rewrite.md new file mode 100644 index 00000000..3338ca5d --- /dev/null +++ b/docs/filters/http/ai/openai_responses_model_rewrite.md @@ -0,0 +1,46 @@ + + + +# `openai_responses_model_rewrite` + +Rewrites the `model` field in Responses API request bodies. + +Requires Cargo feature: `ai-inference`. + +## Configuration + +| Field | Type | Required | Description | +|-------|------|---------|-------------| +| `default_model` | string | no | Model name to inject when the request body has no `model` field or when the field is `null`. At least one of `default_model` or `model_aliases` must be configured. | +| `headers` | ModelRewriteHeaders | no | Header names for promoted model values. | +| `headers.effective_model` | string | no | Header name for the effective (post-rewrite) model value. Defaults to `x-praxis-ai-effective-model`. | +| `headers.original_model` | string | no | Header name for the original (pre-rewrite) model value. Defaults to `x-praxis-ai-original-model`. | +| `max_body_bytes` | usize | no | Maximum request body size to buffer before parsing. | +| `model_aliases` | object | no | Map from client-facing model names to backend model names. Names are matched exactly; models without an alias pass through unchanged. | +| `on_invalid` | `continue` \| `reject` | no | Behavior when the body is not a valid JSON object. Defaults to `continue`. | + +## Examples + +### Example 1 + +```yaml +filter: openai_responses_model_rewrite +default_model: "llama-3.3-70b" +model_aliases: + codex-mini-latest: "llama-3.3-70b" + gpt-4.1-mini: "qwen-2.5-72b" +``` + +### Example 2 + +```yaml +filter: openai_responses_model_rewrite +default_model: "llama-3.3-70b" +model_aliases: + codex-mini-latest: "llama-3.3-70b" +max_body_bytes: 10485760 +on_invalid: continue +headers: + effective_model: x-praxis-ai-effective-model + original_model: x-praxis-ai-original-model +``` diff --git a/docs/filters/reference.md b/docs/filters/reference.md index 7bbdd90e..fb89bb4c 100644 --- a/docs/filters/reference.md +++ b/docs/filters/reference.md @@ -21,6 +21,7 @@ Built-in filters organized by protocol and category. 
| [`model_to_header`](http/ai/model_to_header.md) | `ai-inference` | Promotes the JSON `"model"` field from the request body to a request header. | | [`openai_response_store`](http/ai/openai_response_store.md) | `ai-inference` | Persists non-streaming Responses API responses to the configured response store backend. | | [`openai_responses_format`](http/ai/openai_responses_format.md) | `ai-inference` | Classifies AI API request bodies and promotes routing facts to headers, metadata, and filter results without mutating the body. | +| [`openai_responses_model_rewrite`](http/ai/openai_responses_model_rewrite.md) | `ai-inference` | Rewrites the `model` field in Responses API request bodies. | | [`openai_responses_validate`](http/ai/openai_responses_validate.md) | `ai-inference` | Validates and enriches Responses API requests. | | [`prompt_enrich`](http/ai/prompt_enrich.md) | `ai-inference` | Injects statically configured messages into the `messages` array of OpenAI-compatible chat completion request bodies. | | [`token_usage_headers`](http/ai/token_usage_headers.md) | - | Injects `Praxis-Token-Input`, `Praxis-Token-Output`, and `Praxis-Token-Total` headers into downstream responses when token usage data is present in [`filter_metadata`]. | diff --git a/examples/README.md b/examples/README.md index cfa8e2c3..0a8cb8f6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -35,6 +35,7 @@ page. 
| [model-to-header-routing.yaml](configs/ai/model-to-header-routing.yaml) | Routes LLM API requests to different backends based on the "model" field in the JSON request body | | [format-routing.yaml](configs/ai/openai/responses/format-routing.yaml) | Routes AI API traffic by detected body format | | [full-flow.yaml](configs/ai/openai/responses/full-flow.yaml) | Combines format classification, request validation, and backend routing into a single pipeline | +| [model-rewrite.yaml](configs/ai/openai/responses/model-rewrite.yaml) | Rewrites or injects the top-level `model` field in Responses API request bodies before forwarding to the inference backend | | [request-validate.yaml](configs/ai/openai/responses/request-validate.yaml) | Validates Responses API requests and rejects invalid parameter combinations | | [response-store.yaml](configs/ai/openai/responses/response-store.yaml) | Persists non-streaming Responses API responses to a database and serves stored data via GET endpoints and handles DELETE /v1/responses/{id} locally | | [responses-routing.yaml](configs/ai/openai/responses/responses-routing.yaml) | Routes Responses API traffic by detected mode | diff --git a/examples/configs/ai/openai/responses/model-rewrite.yaml b/examples/configs/ai/openai/responses/model-rewrite.yaml new file mode 100644 index 00000000..77e537de --- /dev/null +++ b/examples/configs/ai/openai/responses/model-rewrite.yaml @@ -0,0 +1,65 @@ +# Responses Model Rewrite +# +# Rewrites or injects the top-level `model` field in Responses API +# request bodies before forwarding to the inference backend. +# +# This pipeline uses `openai_responses_format` to classify the +# request, then `openai_responses_model_rewrite` to apply alias +# mapping. The router uses the effective model header to select +# the backend cluster, so routing reflects the rewritten model, +# not the original client-facing name. 
+# +# Use case: Codex or other Responses API clients send +# `model: "codex-mini-latest"` and the proxy transparently +# rewrites it to a locally-hosted model before forwarding. +# +# Requires the ai-inference feature: +# cargo build -p praxis --features ai-inference + +listeners: + - name: ai-gateway + address: "127.0.0.1:8080" + filter_chains: [model-rewrite-pipeline] + +filter_chains: + - name: model-rewrite-pipeline + filters: + - filter: openai_responses_format + on_invalid: continue + headers: + format: x-praxis-ai-format + model: x-praxis-ai-model + + - filter: openai_responses_model_rewrite + default_model: "llama-3.3-70b" + model_aliases: + codex-mini-latest: "llama-3.3-70b" + gpt-4.1-mini: "qwen-2.5-72b" + headers: + effective_model: x-praxis-ai-effective-model + original_model: x-praxis-ai-original-model + + - filter: router + routes: + - path: "/v1/responses" + headers: + x-praxis-ai-effective-model: "llama-3.3-70b" + cluster: "llama-backend" + - path: "/v1/responses" + headers: + x-praxis-ai-effective-model: "qwen-2.5-72b" + cluster: "qwen-backend" + - path_prefix: "/" + cluster: "default-backend" + + - filter: load_balancer + clusters: + - name: "llama-backend" + endpoints: + - "127.0.0.1:3001" + - name: "qwen-backend" + endpoints: + - "127.0.0.1:3002" + - name: "default-backend" + endpoints: + - "127.0.0.1:3003" diff --git a/filter/src/builtins/http/ai/classifier/mod.rs b/filter/src/builtins/http/ai/classifier/mod.rs index 6f28a694..51d4931f 100644 --- a/filter/src/builtins/http/ai/classifier/mod.rs +++ b/filter/src/builtins/http/ai/classifier/mod.rs @@ -87,7 +87,7 @@ pub(crate) struct ClassifiedRequest { /// - `POST /v1/responses/compact` /// - `DELETE /v1/responses/{id}` pub(crate) fn is_responses_path(method: &http::Method, path: &str) -> bool { - let path = path.strip_suffix('/').filter(|p| !p.is_empty()).unwrap_or(path); + let path = normalize_trailing_slash(path); let segments: Vec<&str> = path.split('/').collect(); match (method, 
segments.as_slice()) { @@ -108,6 +108,14 @@ pub(crate) fn is_responses_path(method: &http::Method, path: &str) -> bool { } } +/// Check whether a method + path pair is the Responses API create endpoint. +/// +/// Returns `true` only for `POST /v1/responses` (with optional trailing slash). +/// Sub-resource POSTs like `/v1/responses/{id}/cancel` return `false`. +pub(crate) fn is_responses_create(method: &http::Method, path: &str) -> bool { + method == http::Method::POST && normalize_trailing_slash(path) == "/v1/responses" +} + // ----------------------------------------------------------------------------- // Body Classification // ----------------------------------------------------------------------------- @@ -212,6 +220,11 @@ fn has_anthropic_signals(obj: &serde_json::Map) -> bo // Private Utilities // ----------------------------------------------------------------------------- +/// Strip a single trailing slash unless the path is the root `/`. +fn normalize_trailing_slash(path: &str) -> &str { + path.strip_suffix('/').filter(|p| !p.is_empty()).unwrap_or(path) +} + /// Build a result with no extracted facts. 
pub(crate) fn empty_result(format: AiRequestFormat) -> ClassifiedRequest { ClassifiedRequest { @@ -859,6 +872,66 @@ mod tests { ); } + // ------------------------------------------------------------------------- + // Create-Endpoint Classification + // ------------------------------------------------------------------------- + + #[test] + fn create_matches_post_v1_responses() { + assert!( + is_responses_create(&http::Method::POST, "/v1/responses"), + "POST /v1/responses should match create" + ); + } + + #[test] + fn create_matches_post_v1_responses_trailing_slash() { + assert!( + is_responses_create(&http::Method::POST, "/v1/responses/"), + "POST /v1/responses/ should match create" + ); + } + + #[test] + fn create_rejects_get() { + assert!( + !is_responses_create(&http::Method::GET, "/v1/responses"), + "GET /v1/responses should not match create" + ); + } + + #[test] + fn create_rejects_cancel_subresource() { + assert!( + !is_responses_create(&http::Method::POST, "/v1/responses/resp_abc/cancel"), + "POST /v1/responses/{{id}}/cancel should not match create" + ); + } + + #[test] + fn create_rejects_input_tokens() { + assert!( + !is_responses_create(&http::Method::POST, "/v1/responses/input_tokens"), + "POST /v1/responses/input_tokens should not match create" + ); + } + + #[test] + fn create_rejects_compact() { + assert!( + !is_responses_create(&http::Method::POST, "/v1/responses/compact"), + "POST /v1/responses/compact should not match create" + ); + } + + #[test] + fn create_rejects_chat_completions() { + assert!( + !is_responses_create(&http::Method::POST, "/v1/chat/completions"), + "POST /v1/chat/completions should not match create" + ); + } + #[test] fn previous_response_id_only_classifies_as_responses() { let body = br#"{"model":"gpt-4.1","previous_response_id":"resp_abc"}"#; diff --git a/filter/src/builtins/http/ai/mod.rs b/filter/src/builtins/http/ai/mod.rs index fdd16065..48bc0f83 100644 --- a/filter/src/builtins/http/ai/mod.rs +++ 
b/filter/src/builtins/http/ai/mod.rs @@ -45,6 +45,8 @@ pub use guardrails::AiGuardrailsFilter; pub use inference::ModelToHeaderFilter; pub(crate) use on_invalid::OnInvalidBehavior; #[cfg(feature = "ai-inference")] +pub use openai::ModelRewriteFilter; +#[cfg(feature = "ai-inference")] pub use openai::OpenaiResponsesValidateFilter; #[cfg(feature = "ai-inference")] pub use openai::ResponseStoreFilter; diff --git a/filter/src/builtins/http/ai/openai/mod.rs b/filter/src/builtins/http/ai/openai/mod.rs index b468ac04..75e41876 100644 --- a/filter/src/builtins/http/ai/openai/mod.rs +++ b/filter/src/builtins/http/ai/openai/mod.rs @@ -5,6 +5,8 @@ pub(crate) mod responses; +#[cfg(feature = "ai-inference")] +pub use responses::ModelRewriteFilter; #[cfg(feature = "ai-inference")] pub use responses::OpenaiResponsesValidateFilter; pub use responses::{ResponseStoreFilter, ResponsesFormatFilter}; diff --git a/filter/src/builtins/http/ai/openai/responses/mod.rs b/filter/src/builtins/http/ai/openai/responses/mod.rs index bc8b0ad0..c49a9116 100644 --- a/filter/src/builtins/http/ai/openai/responses/mod.rs +++ b/filter/src/builtins/http/ai/openai/responses/mod.rs @@ -18,6 +18,8 @@ //! to validate parameter combinations and extract additional fields. 
mod config; +#[cfg(feature = "ai-inference")] +pub(crate) mod model_rewrite; #[expect(clippy::allow_attributes, reason = "dead_code expect unfulfilled on modules")] #[allow( dead_code, @@ -26,6 +28,8 @@ mod config; pub(crate) mod state; pub(crate) mod store; +#[cfg(feature = "ai-inference")] +pub use model_rewrite::ModelRewriteFilter; pub use store::ResponseStoreFilter; #[cfg(test)] diff --git a/filter/src/builtins/http/ai/openai/responses/model_rewrite/config.rs b/filter/src/builtins/http/ai/openai/responses/model_rewrite/config.rs new file mode 100644 index 00000000..077fdcd9 --- /dev/null +++ b/filter/src/builtins/http/ai/openai/responses/model_rewrite/config.rs @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 Praxis Contributors + +//! Deserialized YAML configuration types for the model rewrite filter. + +use std::collections::HashMap; + +use serde::Deserialize; + +use crate::{ + FilterError, body::DEFAULT_JSON_BODY_MAX_BYTES, builtins::http::ai::config_validation::validate_max_body_bytes, +}; + +// ----------------------------------------------------------------------------- +// ModelRewriteConfig +// ----------------------------------------------------------------------------- + +/// Deserialized YAML config for the model rewrite filter. +/// +/// ```yaml +/// filter: openai_responses_model_rewrite +/// default_model: "llama-3.3-70b" +/// model_aliases: +/// codex-mini-latest: "llama-3.3-70b" +/// gpt-4.1-mini: "qwen-2.5-72b" +/// max_body_bytes: 10485760 +/// on_invalid: continue +/// headers: +/// effective_model: x-praxis-ai-effective-model +/// original_model: x-praxis-ai-original-model +/// ``` +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub(super) struct ModelRewriteConfig { + /// Model name to inject when the request body has no `model` + /// field or when the field is `null`. + #[serde(default)] + pub default_model: Option, + + /// Header names for promoted model values. 
+ #[serde(default)] + pub headers: ModelRewriteHeaders, + + /// Maximum request body size to buffer before parsing. + #[serde(default = "default_max_body_bytes")] + pub max_body_bytes: usize, + + /// Map from client-facing model names to backend model names. + #[serde(default)] + pub model_aliases: HashMap, + + /// Behavior when the body is not a valid JSON object. + #[serde(default)] + pub on_invalid: OnInvalidBehavior, +} + +/// Default for `max_body_bytes`. +fn default_max_body_bytes() -> usize { + DEFAULT_JSON_BODY_MAX_BYTES +} + +// ----------------------------------------------------------------------------- +// ModelRewriteHeaders +// ----------------------------------------------------------------------------- + +/// Configurable header names for promoted model values. +#[derive(Debug, Clone, Deserialize)] +#[serde(deny_unknown_fields)] +pub(super) struct ModelRewriteHeaders { + /// Header name for the effective (post-rewrite) model value. + #[serde(default = "default_effective_model_header")] + pub effective_model: Option, + + /// Header name for the original (pre-rewrite) model value. + #[serde(default = "default_original_model_header")] + pub original_model: Option, +} + +impl Default for ModelRewriteHeaders { + fn default() -> Self { + Self { + effective_model: default_effective_model_header(), + original_model: default_original_model_header(), + } + } +} + +/// Default effective model header name. +#[expect( + clippy::unnecessary_wraps, + reason = "serde default functions require Option return type" +)] +fn default_effective_model_header() -> Option { + Some("x-praxis-ai-effective-model".to_owned()) +} + +/// Default original model header name. 
+#[expect( + clippy::unnecessary_wraps, + reason = "serde default functions require Option return type" +)] +fn default_original_model_header() -> Option { + Some("x-praxis-ai-original-model".to_owned()) +} + +// ----------------------------------------------------------------------------- +// OnInvalidBehavior +// ----------------------------------------------------------------------------- + +/// Behavior when the request body cannot be parsed as a JSON object. +#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub(super) enum OnInvalidBehavior { + /// Pass the original body through unchanged. + #[default] + Continue, + + /// Return HTTP 400 with a JSON error body. + Reject, +} + +// ----------------------------------------------------------------------------- +// Validation +// ----------------------------------------------------------------------------- + +/// Validate a parsed config, returning an error for invalid values. +/// +/// # Errors +/// +/// Returns [`FilterError`] when neither `default_model` nor `model_aliases` is set, or when any value is invalid. +/// +/// [`FilterError`]: crate::FilterError +pub(super) fn validate_config(cfg: &ModelRewriteConfig) -> Result<(), FilterError> { + if cfg.default_model.is_none() && cfg.model_aliases.is_empty() { + return Err( + "openai_responses_model_rewrite: at least one of 'default_model' or 'model_aliases' must be configured" + .into(), + ); + } + + if let Some(dm) = &cfg.default_model + && dm.trim().is_empty() + { + return Err("openai_responses_model_rewrite: 'default_model' must not be empty".into()); + } + + validate_aliases(&cfg.model_aliases)?; + validate_max_body_bytes("openai_responses_model_rewrite", cfg.max_body_bytes)?; + validate_header_name("effective_model", cfg.headers.effective_model.as_deref())?; + validate_header_name("original_model", cfg.headers.original_model.as_deref())?; + + Ok(()) +} + +/// Validate alias map entries. 
+fn validate_aliases(aliases: &HashMap) -> Result<(), FilterError> { + for (source, target) in aliases { + if source.trim().is_empty() { + return Err("openai_responses_model_rewrite: alias source name must not be empty".into()); + } + if target.trim().is_empty() { + return Err( + format!("openai_responses_model_rewrite: alias target for '{source}' must not be empty").into(), + ); + } + } + Ok(()) +} + +/// Validate a configured header name with the HTTP header-name parser; `None` (header disabled) is accepted. +fn validate_header_name(field: &str, name: Option<&str>) -> Result<(), FilterError> { + let Some(name) = name else { + return Ok(()); + }; + if name.is_empty() { + return Err(format!("openai_responses_model_rewrite: '{field}' header name must not be empty").into()); + } + if http::HeaderName::from_bytes(name.as_bytes()).is_err() { + return Err( + format!("openai_responses_model_rewrite: '{field}' header name is not a valid HTTP header name").into(), + ); + } + Ok(()) +} diff --git a/filter/src/builtins/http/ai/openai/responses/model_rewrite/mod.rs b/filter/src/builtins/http/ai/openai/responses/model_rewrite/mod.rs new file mode 100644 index 00000000..d54857e7 --- /dev/null +++ b/filter/src/builtins/http/ai/openai/responses/model_rewrite/mod.rs @@ -0,0 +1,466 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 Praxis Contributors + +//! Model rewrite filter for `OpenAI` Responses API requests. +//! +//! Rewrites the top-level `model` field in `POST /v1/responses` +//! request bodies using a configured alias map. When the `model` +//! field is missing or null and a `default_model` is configured, +//! injects the default. Preserves every other field semantically, +//! including `input`, `instructions`, `tools`, and unknown fields. +//! Rewritten requests are re-serialized as JSON, so original +//! whitespace and byte-level object key order are not preserved. +//! +//! Gates on the request path (`POST /v1/responses` exactly), not +//! on classifier metadata. This ensures `on_invalid: reject` fires +//! 
for malformed JSON on the create endpoint even when the +//! classifier could not classify the body. + +mod config; + +#[cfg(test)] +#[expect(clippy::allow_attributes, reason = "blanket test suppressions")] +#[allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::needless_raw_strings, + clippy::needless_raw_string_hashes, + clippy::too_many_lines, + reason = "tests" +)] +mod tests; + +use std::{borrow::Cow, collections::HashMap}; + +use async_trait::async_trait; +use bytes::Bytes; +use tracing::{debug, trace, warn}; + +use self::config::{ModelRewriteConfig, OnInvalidBehavior, validate_config}; +use crate::{ + FilterAction, FilterError, Rejection, + body::{BodyAccess, BodyMode}, + builtins::http::{ai::classifier::is_responses_create, value_safety::is_safe_promoted_value}, + factory::parse_filter_config, + filter::{HttpFilter, HttpFilterContext}, +}; + +// ----------------------------------------------------------------------------- +// Constants +// ----------------------------------------------------------------------------- + +/// Maximum length of a body-derived value promoted to headers or filter results. +const MAX_PROMOTED_VALUE_LEN: usize = 256; + +// ----------------------------------------------------------------------------- +// ModelRewriteFilter +// ----------------------------------------------------------------------------- + +/// Rewrites the `model` field in Responses API request bodies. 
+/// +/// # YAML +/// +/// ```yaml +/// filter: openai_responses_model_rewrite +/// default_model: "llama-3.3-70b" +/// model_aliases: +/// codex-mini-latest: "llama-3.3-70b" +/// gpt-4.1-mini: "qwen-2.5-72b" +/// ``` +/// +/// # Full YAML +/// +/// ```yaml +/// filter: openai_responses_model_rewrite +/// default_model: "llama-3.3-70b" +/// model_aliases: +/// codex-mini-latest: "llama-3.3-70b" +/// max_body_bytes: 10485760 +/// on_invalid: continue +/// headers: +/// effective_model: x-praxis-ai-effective-model +/// original_model: x-praxis-ai-original-model +/// ``` +pub struct ModelRewriteFilter { + /// Model name to inject when absent or null. + default_model: Option, + + /// Configurable header names for promoted model values. + headers: config::ModelRewriteHeaders, + + /// Maximum request body size for `StreamBuffer` mode. + max_body_bytes: usize, + + /// Map from client-facing model names to backend model names. + model_aliases: HashMap, + + /// Behavior when the body is not valid JSON. + on_invalid: OnInvalidBehavior, +} + +impl ModelRewriteFilter { + /// Create a filter from parsed YAML config. + /// + /// # Errors + /// + /// Returns [`FilterError`] if the YAML config is invalid. + /// + /// [`FilterError`]: crate::FilterError + pub fn from_config(config: &serde_yaml::Value) -> Result, FilterError> { + let cfg: ModelRewriteConfig = parse_filter_config("openai_responses_model_rewrite", config)?; + validate_config(&cfg)?; + Ok(Box::new(Self { + default_model: cfg.default_model, + headers: cfg.headers, + max_body_bytes: cfg.max_body_bytes, + model_aliases: cfg.model_aliases, + on_invalid: cfg.on_invalid, + })) + } + + /// Parse, rewrite, and re-serialize the request body. 
+ fn rewrite_body( + &self, + ctx: &mut HttpFilterContext<'_>, + body: &mut Option, + ) -> Result { + let Some(raw) = body.as_ref() else { + return Ok(FilterAction::Continue); + }; + let mut value: serde_json::Value = match serde_json::from_slice(raw) { + Ok(v) => v, + Err(_) => return Ok(invalid_body_action(self.on_invalid)), + }; + + let Some(obj) = value.as_object_mut() else { + return Ok(invalid_body_action(self.on_invalid)); + }; + + let result = apply_rewrite(obj, &self.model_aliases, self.default_model.as_deref()); + promote_facts(ctx, &result, &self.headers); + + if !result.mutated { + return Ok(FilterAction::Continue); + } + + serialize_and_update(ctx, body, &value, &result) + } +} + +#[async_trait] +impl HttpFilter for ModelRewriteFilter { + fn name(&self) -> &'static str { + "openai_responses_model_rewrite" + } + + fn request_body_access(&self) -> BodyAccess { + BodyAccess::ReadWrite + } + + fn request_body_mode(&self) -> BodyMode { + BodyMode::StreamBuffer { + max_bytes: Some(self.max_body_bytes), + } + } + + async fn on_request(&self, ctx: &mut HttpFilterContext<'_>) -> Result { + // Repopulate filter results from metadata written during body + // pre-read. Branch chains evaluate after on_request, and a + // preceding filter's branch evaluation clears filter_results + // before this filter's branches fire. 
+ repopulate_filter_results(ctx); + Ok(FilterAction::Continue) + } + + async fn on_request_body( + &self, + ctx: &mut HttpFilterContext<'_>, + body: &mut Option, + end_of_stream: bool, + ) -> Result { + if !end_of_stream { + return Ok(FilterAction::Continue); + } + + if !is_responses_create(&ctx.request.method, ctx.request.uri.path()) { + trace!("skipping non-create request"); + return Ok(FilterAction::Continue); + } + + self.rewrite_body(ctx, body) + } +} + +// ----------------------------------------------------------------------------- +// Rewrite Logic +// ----------------------------------------------------------------------------- + +/// Outcome of a model rewrite attempt. +#[expect(clippy::struct_excessive_bools, reason = "independent decision flags")] +struct RewriteResult { + /// Whether the default model was injected. + default_injected: bool, + + /// Effective model value after alias/default resolution. + effective_model: String, + + /// Whether the model value was changed in the body. + mutated: bool, + + /// Original model value before rewrite, if present. + original_model: Option, + + /// Whether an alias changed the model. + rewritten: bool, +} + +/// Apply alias and default model policy to the JSON object. +/// +/// Three cases: +/// - Missing or null `model` → inject `default_model` if configured. +/// - String `model` → apply alias mapping or pass through. +/// - Non-string type (number, object, array, bool) → no-op; let the backend validate. The proxy should not silently +/// replace a non-string value with its own default. +fn apply_rewrite( + obj: &mut serde_json::Map, + aliases: &HashMap, + default_model: Option<&str>, +) -> RewriteResult { + match obj.get("model") { + Some(serde_json::Value::String(model)) => apply_alias(obj, aliases, model.clone()), + Some(serde_json::Value::Null) | None => apply_default(obj, default_model), + Some(_) => noop_result(), + } +} + +/// Apply alias mapping when a model field is present. 
+fn apply_alias( + obj: &mut serde_json::Map, + aliases: &HashMap, + model: String, +) -> RewriteResult { + if let Some(target) = aliases.get(&model) { + let effective = target.clone(); + obj.insert("model".to_owned(), serde_json::Value::String(effective.clone())); + RewriteResult { + default_injected: false, + effective_model: effective, + mutated: true, + original_model: Some(model), + rewritten: true, + } + } else { + RewriteResult { + default_injected: false, + effective_model: model.clone(), + mutated: false, + original_model: Some(model), + rewritten: false, + } + } +} + +/// Inject default model when the model field is missing or null. +fn apply_default(obj: &mut serde_json::Map, default_model: Option<&str>) -> RewriteResult { + if let Some(dm) = default_model { + let effective = dm.to_owned(); + obj.insert("model".to_owned(), serde_json::Value::String(effective.clone())); + RewriteResult { + default_injected: true, + effective_model: effective, + mutated: true, + original_model: None, + rewritten: false, + } + } else { + RewriteResult { + default_injected: false, + effective_model: String::new(), + mutated: false, + original_model: None, + rewritten: false, + } + } +} + +/// Build a no-op result for non-string model values. +fn noop_result() -> RewriteResult { + RewriteResult { + default_injected: false, + effective_model: String::new(), + mutated: false, + original_model: None, + rewritten: false, + } +} + +/// Serialize the mutated body, update content-length, and log. 
+fn serialize_and_update( + ctx: &mut HttpFilterContext<'_>, + body: &mut Option, + value: &serde_json::Value, + result: &RewriteResult, +) -> Result { + let serialized = serde_json::to_vec(value).map_err(|e| -> FilterError { + format!("openai_responses_model_rewrite: failed to re-serialize rewritten request body: {e}").into() + })?; + + let len = serialized.len(); + *body = Some(Bytes::from(serialized)); + + ctx.extra_request_headers + .push((Cow::Borrowed("content-length"), len.to_string())); + + debug!( + original = ?result.original_model, + effective = %result.effective_model, + "model rewritten" + ); + + Ok(FilterAction::Continue) +} + +// ----------------------------------------------------------------------------- +// Promotion Helpers +// ----------------------------------------------------------------------------- + +/// Promote rewrite facts to durable metadata, request headers, +/// and filter results. +/// +/// Filter results are also repopulated in `on_request` because +/// a preceding filter's branch evaluation clears them before this +/// filter's branches fire. +fn promote_facts(ctx: &mut HttpFilterContext<'_>, result: &RewriteResult, headers: &config::ModelRewriteHeaders) { + write_metadata(ctx, result); + promote_headers(ctx, result, headers); + set_filter_results_from_result(ctx, result); +} + +/// Write durable metadata for downstream filters. +/// +/// Applies the same safety and length policy used for header +/// promotion so request-controlled model values containing +/// control characters are not written to metadata. 
+fn write_metadata(ctx: &mut HttpFilterContext<'_>, result: &RewriteResult) { + if let Some(orig) = &result.original_model + && is_safe_promoted_value(orig) + && orig.len() <= MAX_PROMOTED_VALUE_LEN + { + ctx.set_metadata("openai_responses_model_rewrite.original_model", orig.clone()); + } + if !result.effective_model.is_empty() + && is_safe_promoted_value(&result.effective_model) + && result.effective_model.len() <= MAX_PROMOTED_VALUE_LEN + { + ctx.set_metadata( + "openai_responses_model_rewrite.effective_model", + result.effective_model.clone(), + ); + } + if result.rewritten { + ctx.set_metadata("openai_responses_model_rewrite.rewritten", "true"); + } + if result.default_injected { + ctx.set_metadata("openai_responses_model_rewrite.default_injected", "true"); + } +} + +/// Promote model values to configurable request headers. +fn promote_headers(ctx: &mut HttpFilterContext<'_>, result: &RewriteResult, headers: &config::ModelRewriteHeaders) { + if let Some(header) = &headers.effective_model + && !result.effective_model.is_empty() + && is_safe_promoted_value(&result.effective_model) + && result.effective_model.len() <= MAX_PROMOTED_VALUE_LEN + { + ctx.extra_request_headers + .push((Cow::Owned(header.clone()), result.effective_model.clone())); + } + + if let Some(header) = &headers.original_model + && let Some(orig) = &result.original_model + && is_safe_promoted_value(orig) + && orig.len() <= MAX_PROMOTED_VALUE_LEN + { + ctx.extra_request_headers + .push((Cow::Owned(header.clone()), orig.clone())); + } +} + +/// Set filter results from a [`RewriteResult`] (body pre-read phase). 
+fn set_filter_results_from_result(ctx: &mut HttpFilterContext<'_>, result: &RewriteResult) { + let results = ctx.filter_results.entry("openai_responses_model_rewrite").or_default(); + + if !result.effective_model.is_empty() + && is_safe_promoted_value(&result.effective_model) + && result.effective_model.len() <= MAX_PROMOTED_VALUE_LEN + { + set_filter_result(results, "effective_model", result.effective_model.clone()); + } + if result.rewritten { + set_filter_result(results, "rewritten", "true"); + } + if result.default_injected { + set_filter_result(results, "default_injected", "true"); + } +} + +/// Repopulate filter results from durable metadata. +/// +/// Branch chains evaluate after `on_request`, not after body +/// pre-read. A preceding filter's branch evaluation clears +/// `filter_results` before this filter's branches fire. This +/// function rebuilds results from the metadata written during +/// body pre-read so branches on this filter work correctly. +fn repopulate_filter_results(ctx: &mut HttpFilterContext<'_>) { + let effective = ctx + .get_metadata("openai_responses_model_rewrite.effective_model") + .map(str::to_owned); + let rewritten = ctx.get_metadata("openai_responses_model_rewrite.rewritten") == Some("true"); + let default_injected = ctx.get_metadata("openai_responses_model_rewrite.default_injected") == Some("true"); + + if effective.is_none() && !rewritten && !default_injected { + return; + } + + let results = ctx.filter_results.entry("openai_responses_model_rewrite").or_default(); + if let Some(eff) = effective { + set_filter_result(results, "effective_model", eff); + } + if rewritten { + set_filter_result(results, "rewritten", "true"); + } + if default_injected { + set_filter_result(results, "default_injected", "true"); + } +} + +// ----------------------------------------------------------------------------- +// Private Utilities +// ----------------------------------------------------------------------------- + +/// Set a filter result and 
log validation failures. +fn set_filter_result( + results: &mut crate::results::FilterResultSet, + key: &'static str, + value: impl Into>, +) { + if let Err(err) = results.set(key, value) { + warn!(error = %err, key, "failed to set model rewrite filter result"); + } +} + +/// Map [`OnInvalidBehavior`] to the appropriate [`FilterAction`]. +fn invalid_body_action(behavior: OnInvalidBehavior) -> FilterAction { + match behavior { + OnInvalidBehavior::Continue => FilterAction::Continue, + OnInvalidBehavior::Reject => FilterAction::Reject( + Rejection::status(400) + .with_header("content-type", "application/json") + .with_body(Bytes::from( + r#"{"error":{"message":"invalid JSON body","type":"invalid_request_error"}}"#, + )), + ), + } +} diff --git a/filter/src/builtins/http/ai/openai/responses/model_rewrite/tests.rs b/filter/src/builtins/http/ai/openai/responses/model_rewrite/tests.rs new file mode 100644 index 00000000..1cd62887 --- /dev/null +++ b/filter/src/builtins/http/ai/openai/responses/model_rewrite/tests.rs @@ -0,0 +1,971 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 Praxis Contributors + +use bytes::Bytes; + +use super::*; + +// ----------------------------------------------------------------------------- +// Config Parsing — Valid +// ----------------------------------------------------------------------------- + +#[test] +fn from_config_minimal_alias_only() { + let yaml: serde_yaml::Value = serde_yaml::from_str( + r#" +model_aliases: + codex-mini-latest: "llama-3.3-70b" +"#, + ) + .unwrap(); + let filter = ModelRewriteFilter::from_config(&yaml).unwrap(); + assert_eq!( + filter.name(), + "openai_responses_model_rewrite", + "filter name should match" + ); +} + +#[test] +fn from_config_minimal_default_only() { + let yaml: serde_yaml::Value = serde_yaml::from_str(r#"default_model: "llama-3.3-70b""#).unwrap(); + let filter = ModelRewriteFilter::from_config(&yaml).unwrap(); + assert_eq!( + filter.name(), + "openai_responses_model_rewrite", + 
"default-only config should parse" + ); +} + +#[test] +fn from_config_full() { + let yaml: serde_yaml::Value = serde_yaml::from_str( + r#" +default_model: "llama-3.3-70b" +model_aliases: + codex-mini-latest: "llama-3.3-70b" + gpt-4.1-mini: "qwen-2.5-72b" +max_body_bytes: 65536 +on_invalid: reject +headers: + effective_model: x-custom-effective + original_model: x-custom-original +"#, + ) + .unwrap(); + let filter = ModelRewriteFilter::from_config(&yaml).unwrap(); + assert_eq!( + filter.name(), + "openai_responses_model_rewrite", + "full config should parse" + ); +} + +// ----------------------------------------------------------------------------- +// Config Parsing — Rejection +// ----------------------------------------------------------------------------- + +#[test] +fn from_config_rejects_empty_default_model() { + let yaml: serde_yaml::Value = serde_yaml::from_str(r#"default_model: """#).unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!(result.is_err(), "empty default_model should be rejected"); +} + +#[test] +fn from_config_rejects_whitespace_default_model() { + let yaml: serde_yaml::Value = serde_yaml::from_str(r#"default_model: " ""#).unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!(result.is_err(), "whitespace-only default_model should be rejected"); +} + +#[test] +fn from_config_rejects_no_default_and_no_aliases() { + let yaml: serde_yaml::Value = serde_yaml::from_str("{}").unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!( + result.is_err(), + "config with neither default_model nor aliases should be rejected" + ); +} + +#[test] +fn from_config_rejects_empty_alias_source() { + let yaml: serde_yaml::Value = serde_yaml::from_str( + r#" +model_aliases: + "": "llama-3.3-70b" +"#, + ) + .unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!(result.is_err(), "empty alias source should be rejected"); +} + +#[test] +fn from_config_rejects_empty_alias_target() { + let 
yaml: serde_yaml::Value = serde_yaml::from_str( + r#" +model_aliases: + codex-mini-latest: "" +"#, + ) + .unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!(result.is_err(), "empty alias target should be rejected"); +} + +#[test] +fn from_config_rejects_zero_max_body_bytes() { + let yaml: serde_yaml::Value = serde_yaml::from_str( + r#" +default_model: "test" +max_body_bytes: 0 +"#, + ) + .unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!(result.is_err(), "zero max_body_bytes should be rejected"); +} + +#[test] +fn from_config_rejects_oversized_max_body_bytes() { + let yaml: serde_yaml::Value = serde_yaml::from_str( + r#" +default_model: "test" +max_body_bytes: 67108865 +"#, + ) + .unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!( + result.is_err(), + "max_body_bytes above 64 MiB ceiling should be rejected" + ); +} + +#[test] +fn from_config_rejects_unknown_fields() { + let yaml: serde_yaml::Value = serde_yaml::from_str( + r#" +default_model: "test" +unknown_field: true +"#, + ) + .unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!(result.is_err(), "unknown fields should be rejected"); +} + +#[test] +fn from_config_rejects_invalid_header_name() { + let yaml: serde_yaml::Value = serde_yaml::from_str( + r#" +default_model: "test" +headers: + effective_model: "invalid header with spaces" +"#, + ) + .unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!(result.is_err(), "header name with spaces should be rejected"); +} + +#[test] +fn from_config_rejects_empty_header_name() { + let yaml: serde_yaml::Value = serde_yaml::from_str( + r#" +default_model: "test" +headers: + effective_model: "" +"#, + ) + .unwrap(); + let result = ModelRewriteFilter::from_config(&yaml); + assert!(result.is_err(), "empty header name should be rejected"); +} + +#[test] +fn from_config_null_headers_suppress_promotion() { + let yaml: serde_yaml::Value = serde_yaml::from_str( + 
r#" +default_model: "test" +headers: + effective_model: null + original_model: null +"#, + ) + .unwrap(); + let filter = ModelRewriteFilter::from_config(&yaml).unwrap(); + assert_eq!( + filter.name(), + "openai_responses_model_rewrite", + "null headers should be accepted" + ); +} + +// ----------------------------------------------------------------------------- +// Trait Properties +// ----------------------------------------------------------------------------- + +#[test] +fn body_access_is_read_write() { + let filter = make_filter(ALIAS_CONFIG); + assert_eq!( + filter.request_body_access(), + BodyAccess::ReadWrite, + "model rewrite must use ReadWrite body access" + ); +} + +#[test] +fn body_mode_is_stream_buffer() { + let filter = make_filter(ALIAS_CONFIG); + match filter.request_body_mode() { + BodyMode::StreamBuffer { max_bytes } => { + assert!(max_bytes.is_some(), "StreamBuffer should have a bounded limit"); + }, + other => panic!("expected StreamBuffer, got {other:?}"), + } +} + +// ----------------------------------------------------------------------------- +// Body Processing — Skip Paths +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn not_end_of_stream_continues() { + let filter = make_filter(ALIAS_CONFIG); + let req = crate::test_utils::make_request(http::Method::POST, "/v1/responses"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body = Some(Bytes::from(r#"{"model":"gpt-4.1","input":"test"}"#)); + + let action = filter.on_request_body(&mut ctx, &mut body, false).await.unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "non-end-of-stream should continue" + ); +} + +#[tokio::test] +async fn empty_body_continues() { + let filter = make_filter(ALIAS_CONFIG); + let req = crate::test_utils::make_request(http::Method::POST, "/v1/responses"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body: Option = None; + + let action = 
filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!(matches!(action, FilterAction::Continue), "empty body should continue"); +} + +#[tokio::test] +async fn chat_completions_path_skips() { + let filter = make_filter(ALIAS_CONFIG); + let req = crate::test_utils::make_request(http::Method::POST, "/v1/chat/completions"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body = Some(Bytes::from(r#"{"model":"codex-mini-latest","messages":[]}"#)); + + let action = filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "chat completions path should skip" + ); + assert!( + !ctx.filter_metadata + .contains_key("openai_responses_model_rewrite.effective_model"), + "non-responses path should not set metadata" + ); +} + +#[tokio::test] +async fn get_responses_id_skips() { + let filter = make_filter(REJECT_CONFIG); + let req = crate::test_utils::make_request(http::Method::GET, "/v1/responses/resp_abc123"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body: Option = None; + + let action = filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "GET /v1/responses/{{id}} should skip even in reject mode" + ); +} + +#[tokio::test] +async fn delete_responses_id_skips() { + let filter = make_filter(REJECT_CONFIG); + let req = crate::test_utils::make_request(http::Method::DELETE, "/v1/responses/resp_abc123"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body: Option = None; + + let action = filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "DELETE /v1/responses/{{id}} should skip even in reject mode" + ); +} + +#[tokio::test] +async fn post_cancel_subresource_skips() { + let filter = make_filter(REJECT_CONFIG); + let req = crate::test_utils::make_request(http::Method::POST, 
"/v1/responses/resp_abc/cancel"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body: Option = Some(Bytes::new()); + + let action = filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "POST /v1/responses/{{id}}/cancel should skip even in reject mode" + ); +} + +#[tokio::test] +async fn post_input_tokens_subresource_skips() { + let filter = make_filter(REJECT_CONFIG); + let req = crate::test_utils::make_request(http::Method::POST, "/v1/responses/input_tokens"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body: Option = Some(Bytes::new()); + + let action = filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "POST /v1/responses/input_tokens should skip" + ); +} + +#[tokio::test] +async fn post_compact_subresource_skips() { + let filter = make_filter(REJECT_CONFIG); + let req = crate::test_utils::make_request(http::Method::POST, "/v1/responses/compact"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body: Option = Some(Bytes::new()); + + let action = filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "POST /v1/responses/compact should skip" + ); +} + +#[tokio::test] +async fn trailing_slash_treated_as_create() { + let filter = make_filter(ALIAS_CONFIG); + let req = crate::test_utils::make_request(http::Method::POST, "/v1/responses/"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body = Some(Bytes::from(r#"{"model":"codex-mini-latest","input":"test"}"#)); + + let action = filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "trailing slash should be treated as create" + ); + assert_eq!( + ctx.filter_metadata + .get("openai_responses_model_rewrite.effective_model") + 
.map(String::as_str), + Some("llama-3.3-70b"), + "model should be rewritten for trailing-slash create" + ); +} + +#[tokio::test] +async fn reject_mode_rejects_malformed_create() { + let filter = make_filter(REJECT_CONFIG); + let req = crate::test_utils::make_request(http::Method::POST, "/v1/responses"); + let mut ctx = crate::test_utils::make_filter_context(&req); + let mut body = Some(Bytes::from("not valid json {{{")); + + let action = filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!( + matches!(action, FilterAction::Reject(_)), + "malformed POST /v1/responses should be rejected in reject mode" + ); +} + +#[tokio::test] +async fn control_char_model_not_promoted_to_metadata() { + let ctx = run_filter(ALIAS_CONFIG, "{\"model\":\"bad\\nmodel\",\"input\":\"test\"}").await; + + assert!( + !ctx.filter_metadata + .contains_key("openai_responses_model_rewrite.effective_model"), + "control-char model should not be written to metadata" + ); + assert!( + !ctx.filter_metadata + .contains_key("openai_responses_model_rewrite.original_model"), + "control-char original model should not be written to metadata" + ); +} + +// ----------------------------------------------------------------------------- +// Body Processing — Invalid Input +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn invalid_json_continue_leaves_body_unchanged() { + let original = b"not json {{{"; + let ctx = run_filter(ALIAS_CONFIG, std::str::from_utf8(original).unwrap()).await; + + assert!( + !ctx.filter_metadata + .contains_key("openai_responses_model_rewrite.effective_model"), + "invalid JSON in continue mode should not set metadata" + ); +} + +#[tokio::test] +async fn invalid_json_reject_returns_400() { + let action = run_filter_raw(REJECT_CONFIG, "not json {{{").await; + assert!( + matches!(action, FilterAction::Reject(_)), + "invalid JSON in reject mode should return 400" + ); +} + +#[tokio::test] +async fn 
non_object_json_continue_leaves_body_unchanged() { + let ctx = run_filter(ALIAS_CONFIG, "[1, 2, 3]").await; + + assert!( + !ctx.filter_metadata + .contains_key("openai_responses_model_rewrite.effective_model"), + "non-object JSON in continue mode should not set metadata" + ); +} + +#[tokio::test] +async fn non_object_json_reject_returns_400() { + let action = run_filter_raw(REJECT_CONFIG, "[1, 2, 3]").await; + assert!( + matches!(action, FilterAction::Reject(_)), + "non-object JSON in reject mode should return 400" + ); +} + +// ----------------------------------------------------------------------------- +// Body Processing — Mutation +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn known_alias_rewrites_model() { + let (ctx, body) = run_filter_with_body(ALIAS_CONFIG, r#"{"model":"codex-mini-latest","input":"test"}"#).await; + + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!( + parsed["model"].as_str(), + Some("llama-3.3-70b"), + "model should be rewritten to alias target" + ); + assert_eq!( + ctx.filter_metadata + .get("openai_responses_model_rewrite.effective_model") + .map(String::as_str), + Some("llama-3.3-70b"), + "effective model metadata" + ); + assert_eq!( + ctx.filter_metadata + .get("openai_responses_model_rewrite.original_model") + .map(String::as_str), + Some("codex-mini-latest"), + "original model metadata" + ); + assert_eq!( + ctx.filter_metadata + .get("openai_responses_model_rewrite.rewritten") + .map(String::as_str), + Some("true"), + "rewritten flag" + ); +} + +#[tokio::test] +async fn missing_model_injects_default() { + let (ctx, body) = run_filter_with_body(DEFAULT_CONFIG, r#"{"input":"test"}"#).await; + + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!( + parsed["model"].as_str(), + Some("llama-3.3-70b"), + "default model should be injected" + ); + assert_eq!( + ctx.filter_metadata + 
.get("openai_responses_model_rewrite.default_injected") + .map(String::as_str), + Some("true"), + "default_injected flag" + ); + assert!( + !ctx.filter_metadata + .contains_key("openai_responses_model_rewrite.original_model"), + "no original model when missing" + ); +} + +#[tokio::test] +async fn null_model_injects_default() { + let (ctx, body) = run_filter_with_body(DEFAULT_CONFIG, r#"{"model":null,"input":"test"}"#).await; + + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!( + parsed["model"].as_str(), + Some("llama-3.3-70b"), + "null model should be replaced with default" + ); + assert_eq!( + ctx.filter_metadata + .get("openai_responses_model_rewrite.default_injected") + .map(String::as_str), + Some("true"), + "default_injected flag" + ); +} + +#[tokio::test] +async fn unknown_model_passes_unchanged() { + let (ctx, body) = run_filter_with_body(ALIAS_CONFIG, r#"{"model":"unknown-model","input":"test"}"#).await; + + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!( + parsed["model"].as_str(), + Some("unknown-model"), + "unknown model should pass through unchanged" + ); + assert_eq!( + ctx.filter_metadata + .get("openai_responses_model_rewrite.effective_model") + .map(String::as_str), + Some("unknown-model"), + "effective model should equal original" + ); + assert!( + !ctx.filter_metadata + .contains_key("openai_responses_model_rewrite.rewritten"), + "rewritten flag should not be set" + ); +} + +#[tokio::test] +async fn no_default_and_no_alias_match_is_noop() { + let (ctx, body) = run_filter_with_body(ALIAS_CONFIG, r#"{"model":"unknown-model","input":"test"}"#).await; + + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!( + parsed["model"].as_str(), + Some("unknown-model"), + "body should not be mutated" + ); + assert!( + ctx.extra_request_headers + .iter() + .all(|(k, _)| k.as_ref() != "content-length"), + "content-length should not be added when body is 
unmodified" + ); +} + +#[tokio::test] +async fn non_string_model_is_noop() { + let (ctx, body) = run_filter_with_body(DEFAULT_CONFIG, r#"{"model":123,"input":"test"}"#).await; + + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!( + parsed["model"].as_u64(), + Some(123), + "numeric model should pass through unchanged" + ); + assert!( + !ctx.filter_metadata + .contains_key("openai_responses_model_rewrite.default_injected"), + "default should not be injected for non-string model" + ); + assert!( + ctx.extra_request_headers + .iter() + .all(|(k, _)| k.as_ref() != "content-length"), + "content-length should not be added when body is unmodified" + ); +} + +#[tokio::test] +async fn object_model_is_noop() { + let (ctx, body) = run_filter_with_body(DEFAULT_CONFIG, r#"{"model":{},"input":"test"}"#).await; + + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert!( + parsed["model"].is_object(), + "object model should pass through unchanged" + ); + assert!( + !ctx.filter_metadata + .contains_key("openai_responses_model_rewrite.default_injected"), + "default should not be injected for object model" + ); +} + +// ----------------------------------------------------------------------------- +// Body Processing — Field Preservation +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn preserves_input_tools_instructions_and_unknown_fields() { + let input_body = r#"{"model":"codex-mini-latest","input":[{"role":"user","content":"Hello"}],"instructions":"Be helpful","tools":[{"type":"function","name":"read_file"}],"custom_field":"preserved","stream":true}"#; + + let (_ctx, body) = run_filter_with_body(ALIAS_CONFIG, input_body).await; + + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!( + parsed["model"].as_str(), + Some("llama-3.3-70b"), + "model should be rewritten" + ); + assert!(parsed["input"].is_array(), "input should be 
preserved"); + assert_eq!( + parsed["instructions"].as_str(), + Some("Be helpful"), + "instructions should be preserved" + ); + assert!(parsed["tools"].is_array(), "tools should be preserved"); + assert_eq!( + parsed["tools"][0]["name"].as_str(), + Some("read_file"), + "tool name should be preserved" + ); + assert_eq!( + parsed["custom_field"].as_str(), + Some("preserved"), + "unknown fields should be preserved" + ); + assert_eq!(parsed["stream"].as_bool(), Some(true), "stream should be preserved"); +} + +#[tokio::test] +async fn preserves_function_call_output_items() { + let input_body = r#"{"model":"codex-mini-latest","input":[{"type":"function_call_output","call_id":"call_abc","output":"{\"result\":42}"},{"role":"user","content":"What happened?"}]}"#; + + let (_ctx, body) = run_filter_with_body(ALIAS_CONFIG, input_body).await; + + let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(parsed["model"].as_str(), Some("llama-3.3-70b"), "model rewritten"); + let items = parsed["input"].as_array().unwrap(); + assert_eq!(items.len(), 2, "input items preserved"); + assert_eq!( + items[0]["type"].as_str(), + Some("function_call_output"), + "function_call_output type preserved" + ); + assert_eq!(items[0]["call_id"].as_str(), Some("call_abc"), "call_id preserved"); +} + +// ----------------------------------------------------------------------------- +// Content-Length Behavior +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn updates_content_length_when_mutated() { + let (ctx, body) = run_filter_with_body(ALIAS_CONFIG, r#"{"model":"codex-mini-latest","input":"test"}"#).await; + + let cl_header = ctx + .extra_request_headers + .iter() + .find(|(k, _)| k.as_ref() == "content-length") + .map(|(_, v)| v.as_str()); + + assert!(cl_header.is_some(), "content-length should be set after mutation"); + let cl_value: usize = cl_header.unwrap().parse().unwrap(); + assert_eq!( + cl_value, + 
body.len(), + "content-length should match serialized body length" + ); +} + +#[tokio::test] +async fn does_not_add_content_length_when_unmodified() { + let ctx = run_filter(ALIAS_CONFIG, r#"{"model":"unknown-model","input":"test"}"#).await; + + assert!( + ctx.extra_request_headers + .iter() + .all(|(k, _)| k.as_ref() != "content-length"), + "content-length should not be added when body is unmodified" + ); +} + +// ----------------------------------------------------------------------------- +// Header and Metadata Promotion +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn sets_effective_model_header_and_metadata() { + let ctx = run_filter(ALIAS_CONFIG, r#"{"model":"codex-mini-latest","input":"test"}"#).await; + let headers = collect_headers(&ctx); + + assert_eq!( + headers.get("x-praxis-ai-effective-model"), + Some(&"llama-3.3-70b"), + "effective model header" + ); +} + +#[tokio::test] +async fn sets_original_model_header_when_present() { + let ctx = run_filter(ALIAS_CONFIG, r#"{"model":"codex-mini-latest","input":"test"}"#).await; + let headers = collect_headers(&ctx); + + assert_eq!( + headers.get("x-praxis-ai-original-model"), + Some(&"codex-mini-latest"), + "original model header" + ); +} + +#[tokio::test] +async fn no_original_model_header_when_model_absent() { + let ctx = run_filter(DEFAULT_CONFIG, r#"{"input":"test"}"#).await; + let headers = collect_headers(&ctx); + + assert!( + !headers.contains_key("x-praxis-ai-original-model"), + "no original model header when model was absent" + ); + assert_eq!( + headers.get("x-praxis-ai-effective-model"), + Some(&"llama-3.3-70b"), + "effective model header should be set for default injection" + ); +} + +#[tokio::test] +async fn custom_headers_emitted() { + let cfg = r#" +default_model: "test-model" +headers: + effective_model: x-custom-eff + original_model: x-custom-orig +"#; + let ctx = run_filter(cfg, r#"{"model":"old","input":"test"}"#).await; + let 
headers = collect_headers(&ctx); + + assert!( + !headers.contains_key("x-praxis-ai-effective-model"), + "default header should not be emitted when overridden" + ); + assert_eq!( + headers.get("x-custom-eff"), + Some(&"old"), + "custom effective model header" + ); + assert_eq!( + headers.get("x-custom-orig"), + Some(&"old"), + "custom original model header" + ); +} + +#[tokio::test] +async fn null_headers_suppress_emission() { + let cfg = r#" +default_model: "test-model" +headers: + effective_model: null + original_model: null +"#; + let ctx = run_filter(cfg, r#"{"model":"old","input":"test"}"#).await; + let headers = collect_headers(&ctx); + + assert!( + !headers.contains_key("x-praxis-ai-effective-model"), + "null effective header suppressed" + ); + assert!( + !headers.contains_key("x-praxis-ai-original-model"), + "null original header suppressed" + ); + assert_eq!( + ctx.filter_metadata + .get("openai_responses_model_rewrite.effective_model") + .map(String::as_str), + Some("old"), + "metadata still written even with null headers" + ); +} + +// ----------------------------------------------------------------------------- +// Filter Results +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn filter_results_record_rewrite_decision() { + let ctx = run_filter(ALIAS_CONFIG, r#"{"model":"codex-mini-latest","input":"test"}"#).await; + let results = ctx.filter_results.get("openai_responses_model_rewrite").unwrap(); + + assert_eq!( + results.get("effective_model"), + Some("llama-3.3-70b"), + "effective_model filter result" + ); + assert_eq!(results.get("rewritten"), Some("true"), "rewritten filter result"); + assert!( + results.get("default_injected").is_none(), + "default_injected should not be set for alias rewrite" + ); +} + +#[tokio::test] +async fn filter_results_record_default_injection() { + let ctx = run_filter(DEFAULT_CONFIG, r#"{"input":"test"}"#).await; + let results = 
ctx.filter_results.get("openai_responses_model_rewrite").unwrap(); + + assert_eq!( + results.get("effective_model"), + Some("llama-3.3-70b"), + "effective_model filter result" + ); + assert_eq!( + results.get("default_injected"), + Some("true"), + "default_injected filter result" + ); + assert!( + results.get("rewritten").is_none(), + "rewritten should not be set for default injection" + ); +} + +#[tokio::test] +async fn on_request_repopulates_filter_results_from_metadata() { + let filter = make_filter(ALIAS_CONFIG); + let req: &'static crate::context::Request = Box::leak(Box::new(crate::test_utils::make_request( + http::Method::POST, + "/v1/responses", + ))); + let mut ctx = crate::test_utils::make_filter_context(req); + + ctx.set_metadata("openai_responses_model_rewrite.effective_model", "llama-3.3-70b"); + ctx.set_metadata("openai_responses_model_rewrite.rewritten", "true"); + + ctx.filter_results.clear(); + assert!( + !ctx.filter_results.contains_key("openai_responses_model_rewrite"), + "results should be cleared before on_request" + ); + + let action = filter.on_request(&mut ctx).await.unwrap(); + assert!(matches!(action, FilterAction::Continue), "on_request should continue"); + + let results = ctx.filter_results.get("openai_responses_model_rewrite").unwrap(); + assert_eq!( + results.get("effective_model"), + Some("llama-3.3-70b"), + "on_request should repopulate effective_model from metadata" + ); + assert_eq!( + results.get("rewritten"), + Some("true"), + "on_request should repopulate rewritten from metadata" + ); +} + +#[tokio::test] +async fn on_request_skips_repopulation_when_no_metadata() { + let filter = make_filter(ALIAS_CONFIG); + let req: &'static crate::context::Request = Box::leak(Box::new(crate::test_utils::make_request( + http::Method::POST, + "/v1/responses", + ))); + let mut ctx = crate::test_utils::make_filter_context(req); + + let action = filter.on_request(&mut ctx).await.unwrap(); + assert!(matches!(action, FilterAction::Continue), 
"on_request should continue"); + + assert!( + !ctx.filter_results.contains_key("openai_responses_model_rewrite"), + "on_request should not create results when no metadata exists" + ); +} + +// ----------------------------------------------------------------------------- +// Test Utilities +// ----------------------------------------------------------------------------- + +/// Config with only alias mapping. +const ALIAS_CONFIG: &str = r#" +model_aliases: + codex-mini-latest: "llama-3.3-70b" + gpt-4.1-mini: "qwen-2.5-72b" +"#; + +/// Config with only default model. +const DEFAULT_CONFIG: &str = r#"default_model: "llama-3.3-70b""#; + +/// Config with reject behavior. +const REJECT_CONFIG: &str = r#" +model_aliases: + codex-mini-latest: "llama-3.3-70b" +on_invalid: reject +"#; + +/// Build a [`ModelRewriteFilter`] from a YAML snippet. +fn make_filter(yaml_str: &str) -> Box { + let yaml: serde_yaml::Value = serde_yaml::from_str(yaml_str).unwrap(); + ModelRewriteFilter::from_config(&yaml).unwrap() +} + +/// Run the filter and return the resulting context. +async fn run_filter(config_yaml: &str, body_str: &str) -> HttpFilterContext<'static> { + let filter = make_filter(config_yaml); + let req: &'static crate::context::Request = Box::leak(Box::new(crate::test_utils::make_request( + http::Method::POST, + "/v1/responses", + ))); + let mut ctx = crate::test_utils::make_filter_context(req); + let mut body = Some(Bytes::from(body_str.to_owned())); + + let action = filter.on_request_body(&mut ctx, &mut body, true).await.unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "valid request should continue: got {action:?}" + ); + ctx +} + +/// Run the filter and return both context and final body bytes. 
+async fn run_filter_with_body(config_yaml: &str, body_str: &str) -> (HttpFilterContext<'static>, Bytes) {
+    // The request is leaked so the returned context can borrow it for `'static`.
+    let request = crate::test_utils::make_request(http::Method::POST, "/v1/responses");
+    let leaked: &'static crate::context::Request = Box::leak(Box::new(request));
+    let filter = make_filter(config_yaml);
+    let mut ctx = crate::test_utils::make_filter_context(leaked);
+    let mut payload = Some(Bytes::from(body_str.to_owned()));
+
+    // Deliver the whole body in one call, flagged as end of stream.
+    let outcome = filter.on_request_body(&mut ctx, &mut payload, true).await.unwrap();
+    let action = outcome;
+    assert!(
+        matches!(action, FilterAction::Continue),
+        "valid request should continue: got {action:?}"
+    );
+
+    let final_body = payload.unwrap();
+    (ctx, final_body)
+}
+
+/// Feed `body_str` to a filter built from `config_yaml` as a single
+/// end-of-stream chunk of a `POST /v1/responses` request and hand back the
+/// resulting action without asserting on it.
+async fn run_filter_raw(config_yaml: &str, body_str: &str) -> FilterAction {
+    let request = crate::test_utils::make_request(http::Method::POST, "/v1/responses");
+    let leaked: &'static crate::context::Request = Box::leak(Box::new(request));
+    let filter = make_filter(config_yaml);
+    let mut ctx = crate::test_utils::make_filter_context(leaked);
+    let mut payload = Some(Bytes::from(body_str.to_owned()));
+
+    let outcome = filter.on_request_body(&mut ctx, &mut payload, true).await;
+    outcome.unwrap()
+}
+
+/// Collect extra request headers into a map for assertions.
+fn collect_headers<'a>(ctx: &'a HttpFilterContext<'_>) -> HashMap<&'a str, &'a str> { + ctx.extra_request_headers + .iter() + .map(|(k, v)| (k.as_ref(), v.as_str())) + .collect() +} diff --git a/filter/src/builtins/http/mod.rs b/filter/src/builtins/http/mod.rs index cb933519..f1622522 100644 --- a/filter/src/builtins/http/mod.rs +++ b/filter/src/builtins/http/mod.rs @@ -24,6 +24,8 @@ pub use ai::AnthropicToOpenaiFilter; #[cfg(feature = "ai-inference")] pub use ai::AnthropicValidateFilter; #[cfg(feature = "ai-inference")] +pub use ai::ModelRewriteFilter; +#[cfg(feature = "ai-inference")] pub use ai::ModelToHeaderFilter; #[cfg(feature = "ai-inference")] pub use ai::OpenaiResponsesValidateFilter; diff --git a/filter/src/builtins/mod.rs b/filter/src/builtins/mod.rs index 23c99ada..08b8525e 100644 --- a/filter/src/builtins/mod.rs +++ b/filter/src/builtins/mod.rs @@ -19,6 +19,8 @@ pub use http::AnthropicToOpenaiFilter; #[cfg(feature = "ai-inference")] pub use http::AnthropicValidateFilter; #[cfg(feature = "ai-inference")] +pub use http::ModelRewriteFilter; +#[cfg(feature = "ai-inference")] pub use http::ModelToHeaderFilter; #[cfg(feature = "ai-inference")] pub use http::OpenaiResponsesValidateFilter; diff --git a/filter/src/registry.rs b/filter/src/registry.rs index 87d2f258..392995d3 100644 --- a/filter/src/registry.rs +++ b/filter/src/registry.rs @@ -199,6 +199,12 @@ fn register_http_builtins(factories: &mut HashMap) { crate::builtins::ResponsesFormatFilter::from_config, ); #[cfg(feature = "ai-inference")] + register_http( + factories, + "openai_responses_model_rewrite", + crate::builtins::ModelRewriteFilter::from_config, + ); + #[cfg(feature = "ai-inference")] register_http( factories, "openai_responses_validate", @@ -363,6 +369,11 @@ mod tests { "openai_responses_format should be registered" ); #[cfg(feature = "ai-inference")] + assert!( + names.contains(&"openai_responses_model_rewrite"), + "openai_responses_model_rewrite should be registered" + ); + #[cfg(feature 
= "ai-inference")] assert!( names.contains(&"openai_responses_validate"), "validate should be registered" diff --git a/tests/integration/tests/suite/examples/mod.rs b/tests/integration/tests/suite/examples/mod.rs index f814c557..43a73515 100644 --- a/tests/integration/tests/suite/examples/mod.rs +++ b/tests/integration/tests/suite/examples/mod.rs @@ -41,6 +41,8 @@ mod openai_response_store_postgres; #[cfg(feature = "ai-inference")] mod openai_responses_format; #[cfg(feature = "ai-inference")] +mod openai_responses_model_rewrite; +#[cfg(feature = "ai-inference")] mod openai_responses_validate; mod p2c; mod path_based_routing; diff --git a/tests/integration/tests/suite/examples/openai_responses_model_rewrite.rs b/tests/integration/tests/suite/examples/openai_responses_model_rewrite.rs new file mode 100644 index 00000000..f461999d --- /dev/null +++ b/tests/integration/tests/suite/examples/openai_responses_model_rewrite.rs @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 Praxis Contributors + +//! Functional test for the model-rewrite example config. 
+ +use std::collections::HashMap; + +use praxis_test_utils::{ + free_port, http_send, json_post, load_example_config, parse_body, parse_status, start_backend_with_shutdown, + start_proxy, +}; + +// ----------------------------------------------------------------------------- +// Tests +// ----------------------------------------------------------------------------- + +#[test] +fn example_config_alias_routes_to_llama_backend() { + let llama_guard = start_backend_with_shutdown("llama-backend"); + let qwen_guard = start_backend_with_shutdown("qwen-backend"); + let default_guard = start_backend_with_shutdown("default-backend"); + let proxy_port = free_port(); + + let config = load_example_config( + "ai/openai/responses/model-rewrite.yaml", + proxy_port, + HashMap::from([ + ("127.0.0.1:3001", llama_guard.port()), + ("127.0.0.1:3002", qwen_guard.port()), + ("127.0.0.1:3003", default_guard.port()), + ]), + ); + let proxy = start_proxy(&config); + + let body = r#"{"model":"codex-mini-latest","input":"Hello from example test"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "example config should route successfully"); + assert_eq!( + parse_body(&raw), + "llama-backend", + "codex-mini-latest should route to llama-backend via effective model header" + ); +} + +#[test] +fn example_config_default_model_routes_to_llama_backend() { + let llama_guard = start_backend_with_shutdown("llama-backend"); + let qwen_guard = start_backend_with_shutdown("qwen-backend"); + let default_guard = start_backend_with_shutdown("default-backend"); + let proxy_port = free_port(); + + let config = load_example_config( + "ai/openai/responses/model-rewrite.yaml", + proxy_port, + HashMap::from([ + ("127.0.0.1:3001", llama_guard.port()), + ("127.0.0.1:3002", qwen_guard.port()), + ("127.0.0.1:3003", default_guard.port()), + ]), + ); + let proxy = start_proxy(&config); + + let body = r#"{"input":"No model specified"}"#; + let raw = 
http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "default model injection should succeed"); + assert_eq!( + parse_body(&raw), + "llama-backend", + "default_model llama-3.3-70b should route to llama-backend" + ); +} + +#[test] +fn example_config_qwen_alias_routes_to_qwen_backend() { + let llama_guard = start_backend_with_shutdown("llama-backend"); + let qwen_guard = start_backend_with_shutdown("qwen-backend"); + let default_guard = start_backend_with_shutdown("default-backend"); + let proxy_port = free_port(); + + let config = load_example_config( + "ai/openai/responses/model-rewrite.yaml", + proxy_port, + HashMap::from([ + ("127.0.0.1:3001", llama_guard.port()), + ("127.0.0.1:3002", qwen_guard.port()), + ("127.0.0.1:3003", default_guard.port()), + ]), + ); + let proxy = start_proxy(&config); + + let body = r#"{"model":"gpt-4.1-mini","input":"Route to qwen"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "qwen alias should route successfully"); + assert_eq!( + parse_body(&raw), + "qwen-backend", + "gpt-4.1-mini should route to qwen-backend via effective model header" + ); +} + +#[test] +fn example_config_non_responses_routes_to_default() { + let llama_guard = start_backend_with_shutdown("llama-backend"); + let qwen_guard = start_backend_with_shutdown("qwen-backend"); + let default_guard = start_backend_with_shutdown("default-backend"); + let proxy_port = free_port(); + + let config = load_example_config( + "ai/openai/responses/model-rewrite.yaml", + proxy_port, + HashMap::from([ + ("127.0.0.1:3001", llama_guard.port()), + ("127.0.0.1:3002", qwen_guard.port()), + ("127.0.0.1:3003", default_guard.port()), + ]), + ); + let proxy = start_proxy(&config); + + let body = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/chat/completions", body)); + + assert_eq!(parse_status(&raw), 200, 
"chat completions should return 200"); + assert_eq!( + parse_body(&raw), + "default-backend", + "non-responses traffic should route to default" + ); +} diff --git a/tests/integration/tests/suite/main.rs b/tests/integration/tests/suite/main.rs index ad05dcd2..178e0ea6 100644 --- a/tests/integration/tests/suite/main.rs +++ b/tests/integration/tests/suite/main.rs @@ -63,6 +63,8 @@ mod mcp; mod mcp_broker; #[cfg(feature = "ai-inference")] mod openai_responses_format; +#[cfg(feature = "ai-inference")] +mod openai_responses_model_rewrite; mod path_rewrite; mod payload_processing; mod per_listener_pipeline; diff --git a/tests/integration/tests/suite/openai_responses_model_rewrite.rs b/tests/integration/tests/suite/openai_responses_model_rewrite.rs new file mode 100644 index 00000000..c694235a --- /dev/null +++ b/tests/integration/tests/suite/openai_responses_model_rewrite.rs @@ -0,0 +1,477 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 Praxis Contributors + +use praxis_core::config::Config; +use praxis_test_utils::{ + free_port, http_send, json_post, parse_body, parse_status, start_backend_with_shutdown, start_echo_backend, + start_header_echo_backend, start_proxy, +}; + +// ----------------------------------------------------------------------------- +// Model Alias Rewrite +// ----------------------------------------------------------------------------- + +#[test] +fn responses_model_alias_reaches_upstream() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = Config::from_yaml(&rewrite_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let body = r#"{"model":"codex-mini-latest","input":"Hello, world!"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "proxy should return 200"); + let echoed: serde_json::Value = serde_json::from_str(&parse_body(&raw)).unwrap(); + assert_eq!( + echoed["model"].as_str(), + 
Some("llama-3.3-70b"), + "model should be rewritten to alias target" + ); +} + +// ----------------------------------------------------------------------------- +// Default Model Injection +// ----------------------------------------------------------------------------- + +#[test] +fn responses_default_model_reaches_upstream() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = Config::from_yaml(&rewrite_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let body = r#"{"input":"Hello, no model specified"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "proxy should return 200"); + let echoed: serde_json::Value = serde_json::from_str(&parse_body(&raw)).unwrap(); + assert_eq!( + echoed["model"].as_str(), + Some("llama-3.3-70b"), + "default model should be injected" + ); +} + +// ----------------------------------------------------------------------------- +// Unknown Model Passthrough +// ----------------------------------------------------------------------------- + +#[test] +fn responses_unknown_model_reaches_upstream_unchanged() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = Config::from_yaml(&rewrite_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let body = r#"{"model":"some-unknown-model","input":"test"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "proxy should return 200"); + let echoed: serde_json::Value = serde_json::from_str(&parse_body(&raw)).unwrap(); + assert_eq!( + echoed["model"].as_str(), + Some("some-unknown-model"), + "unknown model should pass through unchanged" + ); +} + +// ----------------------------------------------------------------------------- +// Tool Preservation +// ----------------------------------------------------------------------------- 
+ +#[test] +fn responses_tool_request_preserves_tools() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = Config::from_yaml(&rewrite_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let body = r#"{"model":"codex-mini-latest","input":"test","tools":[{"type":"function","name":"read_file","description":"Read a file","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}]}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "proxy should return 200"); + let echoed: serde_json::Value = serde_json::from_str(&parse_body(&raw)).unwrap(); + assert_eq!( + echoed["model"].as_str(), + Some("llama-3.3-70b"), + "model should be rewritten" + ); + assert!(echoed["tools"].is_array(), "tools array should be preserved"); + assert_eq!( + echoed["tools"][0]["name"].as_str(), + Some("read_file"), + "tool name should be preserved" + ); + assert!( + echoed["tools"][0]["parameters"].is_object(), + "tool parameters should be preserved" + ); +} + +// ----------------------------------------------------------------------------- +// Function Call Output Preservation +// ----------------------------------------------------------------------------- + +#[test] +fn responses_function_call_output_preserved() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = Config::from_yaml(&rewrite_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let body = r#"{"model":"codex-mini-latest","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"result\":42}"},{"role":"user","content":"What happened?"}]}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "proxy should return 200"); + let echoed: serde_json::Value = serde_json::from_str(&parse_body(&raw)).unwrap(); + 
assert_eq!(echoed["model"].as_str(), Some("llama-3.3-70b"), "model rewritten"); + let items = echoed["input"].as_array().unwrap(); + assert_eq!(items.len(), 2, "input items should be preserved"); + assert_eq!( + items[0]["type"].as_str(), + Some("function_call_output"), + "function_call_output type preserved" + ); + assert_eq!(items[0]["call_id"].as_str(), Some("call_abc123"), "call_id preserved"); + assert_eq!( + items[0]["output"].as_str(), + Some("{\"result\":42}"), + "output field preserved" + ); +} + +// ----------------------------------------------------------------------------- +// Non-Responses Traffic +// ----------------------------------------------------------------------------- + +#[test] +fn non_responses_chat_body_passes_unchanged() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = Config::from_yaml(&rewrite_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let body = r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/chat/completions", body)); + + assert_eq!(parse_status(&raw), 200, "chat completions should return 200"); + assert_eq!( + parse_body(&raw), + body, + "chat completions body should pass through unchanged" + ); +} + +// ----------------------------------------------------------------------------- +// Content-Length +// ----------------------------------------------------------------------------- + +#[test] +fn content_length_matches_mutated_body() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = Config::from_yaml(&rewrite_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let body = r#"{"model":"codex-mini-latest","input":"test"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "proxy should return 200"); + let echoed_body = parse_body(&raw); + let 
echoed_len = echoed_body.len(); + + drop(proxy); + drop(echo_guard); + + let header_guard = start_header_echo_backend(); + let proxy_port2 = free_port(); + let config2 = Config::from_yaml(&header_echo_yaml(proxy_port2, header_guard.port())).unwrap(); + let proxy2 = start_proxy(&config2); + + let raw2 = http_send(proxy2.addr(), &json_post("/v1/responses", body)); + assert_eq!(parse_status(&raw2), 200, "header echo should return 200"); + let headers_text = parse_body(&raw2); + + let cl_lines: Vec<&str> = headers_text + .lines() + .filter(|l| l.to_lowercase().starts_with("content-length:")) + .collect(); + assert_eq!( + cl_lines.len(), + 1, + "backend should receive exactly one Content-Length header, got {cl_lines:?}" + ); + + let cl_value: usize = cl_lines[0] + .split(':') + .nth(1) + .unwrap() + .trim() + .parse() + .expect("content-length should be a number"); + assert_eq!( + cl_value, echoed_len, + "upstream content-length ({cl_value}) should match rewritten body length ({echoed_len})" + ); +} + +// ----------------------------------------------------------------------------- +// Effective Model Header Routing +// ----------------------------------------------------------------------------- + +#[test] +fn effective_model_header_routes_to_expected_backend() { + let llama_guard = start_backend_with_shutdown("llama-backend"); + let qwen_guard = start_backend_with_shutdown("qwen-backend"); + let default_guard = start_backend_with_shutdown("default-backend"); + let proxy_port = free_port(); + + let config = Config::from_yaml(&effective_model_routing_yaml( + proxy_port, + llama_guard.port(), + qwen_guard.port(), + default_guard.port(), + )) + .unwrap(); + let proxy = start_proxy(&config); + + let body = r#"{"model":"codex-mini-latest","input":"route to llama"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + assert_eq!(parse_status(&raw), 200, "llama route should return 200"); + assert_eq!( + parse_body(&raw), + "llama-backend", + 
"codex-mini-latest should route to llama-backend via effective model header" + ); + + let body = r#"{"model":"gpt-4.1-mini","input":"route to qwen"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + assert_eq!(parse_status(&raw), 200, "qwen route should return 200"); + assert_eq!( + parse_body(&raw), + "qwen-backend", + "gpt-4.1-mini should route to qwen-backend via effective model header" + ); + + let body = r#"{"model":"unknown-model","input":"route to default"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + assert_eq!(parse_status(&raw), 200, "default route should return 200"); + assert_eq!( + parse_body(&raw), + "default-backend", + "unknown model should route to default-backend" + ); +} + +// ----------------------------------------------------------------------------- +// Null Model Default Injection +// ----------------------------------------------------------------------------- + +#[test] +fn responses_null_model_receives_default() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = Config::from_yaml(&rewrite_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let body = r#"{"model":null,"input":"null model test"}"#; + let raw = http_send(proxy.addr(), &json_post("/v1/responses", body)); + + assert_eq!(parse_status(&raw), 200, "proxy should return 200"); + let echoed: serde_json::Value = serde_json::from_str(&parse_body(&raw)).unwrap(); + assert_eq!( + echoed["model"].as_str(), + Some("llama-3.3-70b"), + "null model should receive configured default" + ); +} + +// ----------------------------------------------------------------------------- +// Malformed JSON Rejection +// ----------------------------------------------------------------------------- + +#[test] +fn malformed_json_rejected_in_reject_mode() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = 
Config::from_yaml(&reject_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let raw = http_send(proxy.addr(), &json_post("/v1/responses", "not valid json {{{")); + + assert_eq!(parse_status(&raw), 400, "malformed JSON should return 400"); + let error_body: serde_json::Value = serde_json::from_str(&parse_body(&raw)).unwrap(); + assert_eq!( + error_body["error"]["type"].as_str(), + Some("invalid_request_error"), + "rejection should have structured error type" + ); + assert_eq!( + error_body["error"]["message"].as_str(), + Some("invalid JSON body"), + "rejection should have descriptive message" + ); +} + +#[test] +fn malformed_json_continues_in_continue_mode() { + let echo_guard = start_echo_backend(); + let proxy_port = free_port(); + + let config = Config::from_yaml(&rewrite_yaml(proxy_port, echo_guard.port())).unwrap(); + let proxy = start_proxy(&config); + + let raw = http_send(proxy.addr(), &json_post("/v1/responses", "not valid json {{{")); + + assert_eq!( + parse_status(&raw), + 200, + "malformed JSON in continue mode should pass through" + ); +} + +// ----------------------------------------------------------------------------- +// Test Utilities +// ----------------------------------------------------------------------------- + +/// YAML config with model rewrite + `on_invalid: reject` for malformed body testing. 
+fn reject_yaml(proxy_port: u16, backend_port: u16) -> String { + format!( + r#" +listeners: + - name: default + address: "127.0.0.1:{proxy_port}" + filter_chains: [main] +filter_chains: + - name: main + filters: + - filter: openai_responses_format + - filter: openai_responses_model_rewrite + default_model: "llama-3.3-70b" + model_aliases: + codex-mini-latest: "llama-3.3-70b" + on_invalid: reject + - filter: router + routes: + - path_prefix: "/" + cluster: "backend" + - filter: load_balancer + clusters: + - name: "backend" + endpoints: + - "127.0.0.1:{backend_port}" +"# + ) +} + +/// YAML config with model rewrite + echo backend. +fn rewrite_yaml(proxy_port: u16, backend_port: u16) -> String { + format!( + r#" +listeners: + - name: default + address: "127.0.0.1:{proxy_port}" + filter_chains: [main] +filter_chains: + - name: main + filters: + - filter: openai_responses_format + - filter: openai_responses_model_rewrite + default_model: "llama-3.3-70b" + model_aliases: + codex-mini-latest: "llama-3.3-70b" + gpt-4.1-mini: "qwen-2.5-72b" + - filter: router + routes: + - path_prefix: "/" + cluster: "backend" + - filter: load_balancer + clusters: + - name: "backend" + endpoints: + - "127.0.0.1:{backend_port}" +"# + ) +} + +/// YAML config with header-echo backend for content-length verification. +fn header_echo_yaml(proxy_port: u16, backend_port: u16) -> String { + format!( + r#" +listeners: + - name: default + address: "127.0.0.1:{proxy_port}" + filter_chains: [main] +filter_chains: + - name: main + filters: + - filter: openai_responses_format + - filter: openai_responses_model_rewrite + default_model: "llama-3.3-70b" + model_aliases: + codex-mini-latest: "llama-3.3-70b" + - filter: router + routes: + - path_prefix: "/" + cluster: "backend" + - filter: load_balancer + clusters: + - name: "backend" + endpoints: + - "127.0.0.1:{backend_port}" +"# + ) +} + +/// YAML config that routes by effective model header to different backends. 
+fn effective_model_routing_yaml(proxy_port: u16, llama_port: u16, qwen_port: u16, default_port: u16) -> String { + format!( + r#" +listeners: + - name: default + address: "127.0.0.1:{proxy_port}" + filter_chains: [main] +filter_chains: + - name: main + filters: + - filter: openai_responses_format + - filter: openai_responses_model_rewrite + model_aliases: + codex-mini-latest: "llama-3.3-70b" + gpt-4.1-mini: "qwen-2.5-72b" + - filter: router + routes: + - path: "/v1/responses" + headers: + x-praxis-ai-effective-model: "llama-3.3-70b" + cluster: "llama" + - path: "/v1/responses" + headers: + x-praxis-ai-effective-model: "qwen-2.5-72b" + cluster: "qwen" + - path_prefix: "/" + cluster: "default" + - filter: load_balancer + clusters: + - name: "llama" + endpoints: + - "127.0.0.1:{llama_port}" + - name: "qwen" + endpoints: + - "127.0.0.1:{qwen_port}" + - name: "default" + endpoints: + - "127.0.0.1:{default_port}" +"# + ) +}