diff --git a/.gitignore b/.gitignore index 03d58da46..9b0dea3b8 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,6 @@ clients/java/target/ clients/java/build.sbt clients/java/gradle.properties openapitools.json + +# Local MCP config +mcp.local.yaml diff --git a/crates/mcp/src/core/config.rs b/crates/mcp/src/core/config.rs index 3ceeeb659..302fc4623 100644 --- a/crates/mcp/src/core/config.rs +++ b/crates/mcp/src/core/config.rs @@ -292,6 +292,8 @@ pub enum BuiltinToolType { WebSearchPreview, /// Code interpreter tool (OpenAI: code_interpreter) CodeInterpreter, + /// Image generation tool (OpenAI: image_generation) + ImageGeneration, /// File search tool (OpenAI: file_search) FileSearch, } @@ -302,6 +304,7 @@ impl BuiltinToolType { match self { BuiltinToolType::WebSearchPreview => ResponseFormatConfig::WebSearchCall, BuiltinToolType::CodeInterpreter => ResponseFormatConfig::CodeInterpreterCall, + BuiltinToolType::ImageGeneration => ResponseFormatConfig::ImageGenerationCall, BuiltinToolType::FileSearch => ResponseFormatConfig::FileSearchCall, } } @@ -312,6 +315,7 @@ impl fmt::Display for BuiltinToolType { match self { BuiltinToolType::WebSearchPreview => write!(f, "web_search_preview"), BuiltinToolType::CodeInterpreter => write!(f, "code_interpreter"), + BuiltinToolType::ImageGeneration => write!(f, "image_generation"), BuiltinToolType::FileSearch => write!(f, "file_search"), } } @@ -341,6 +345,7 @@ pub enum ResponseFormatConfig { Passthrough, WebSearchCall, CodeInterpreterCall, + ImageGenerationCall, FileSearchCall, } @@ -1016,6 +1021,10 @@ tools: ResponseFormatConfig::CodeInterpreterCall, "\"code_interpreter_call\"", ), + ( + ResponseFormatConfig::ImageGenerationCall, + "\"image_generation_call\"", + ), (ResponseFormatConfig::FileSearchCall, "\"file_search_call\""), ]; @@ -1182,6 +1191,7 @@ policy: let types = vec![ (BuiltinToolType::WebSearchPreview, "\"web_search_preview\""), (BuiltinToolType::CodeInterpreter, "\"code_interpreter\""), + (BuiltinToolType::ImageGeneration, "\"image_generation\""), (BuiltinToolType::FileSearch, "\"file_search\""), ]; @@ -1204,6 +1214,10 @@ policy: BuiltinToolType::CodeInterpreter.response_format(), ResponseFormatConfig::CodeInterpreterCall ); + assert_eq!( + BuiltinToolType::ImageGeneration.response_format(), + ResponseFormatConfig::ImageGenerationCall + ); assert_eq!( BuiltinToolType::FileSearch.response_format(), ResponseFormatConfig::FileSearchCall @@ -1475,6 +1489,10 @@ servers: BuiltinToolType::CodeInterpreter.to_string(), "code_interpreter" ); + assert_eq!( + BuiltinToolType::ImageGeneration.to_string(), + "image_generation" + ); assert_eq!(BuiltinToolType::FileSearch.to_string(), "file_search"); } } diff --git a/crates/mcp/src/lib.rs b/crates/mcp/src/lib.rs index 6c218447a..2b6412ba9 100644 --- a/crates/mcp/src/lib.rs +++ b/crates/mcp/src/lib.rs @@ -56,4 +56,7 @@ pub use responses_bridge::{ }; pub use tenant::{SessionId, TenantContext, TenantId}; // Re-export from transform -pub use transform::{mcp_response_item_id, ResponseFormat, ResponseTransformer}; +pub use transform::{ + extract_image_generation_fallback_text, is_image_generation_error, mcp_response_item_id, + ResponseFormat, ResponseTransformer, +}; diff --git a/crates/mcp/src/transform/mod.rs b/crates/mcp/src/transform/mod.rs index 73fecb305..32a1bba81 100644 --- a/crates/mcp/src/transform/mod.rs +++ b/crates/mcp/src/transform/mod.rs @@ -22,5 +22,8 @@ mod transformer; mod types; -pub use transformer::{mcp_response_item_id, ResponseTransformer}; +pub use transformer::{ + extract_image_generation_fallback_text, is_image_generation_error, mcp_response_item_id, + ResponseTransformer, +}; pub use types::ResponseFormat; diff --git a/crates/mcp/src/transform/transformer.rs b/crates/mcp/src/transform/transformer.rs index 0bbf89312..5e53912f8 100644 --- a/crates/mcp/src/transform/transformer.rs +++ b/crates/mcp/src/transform/transformer.rs @@ -2,8 +2,10 @@ use openai_protocol::responses::{ CodeInterpreterCallStatus, CodeInterpreterOutput, FileSearchCallStatus, FileSearchResult, - ResponseOutputItem, WebSearchAction, WebSearchCallStatus, WebSearchSource, + ImageGenerationCallStatus, ResponseOutputItem, WebSearchAction, WebSearchCallStatus, + WebSearchSource, }; +use serde_json::Value; use super::ResponseFormat; @@ -26,15 +28,101 @@ pub fn mcp_response_item_id(source_id: &str) -> String { format!("mcp_{source_id}") } +/// Extract image-generation fallback text from JSON-RPC result content wrapper: +/// `{"result":{"content":[{"type":"text","text":"..."}]}}`. +pub fn extract_image_generation_fallback_text(value: &Value) -> Option { + value + .as_object() + .and_then(|obj| obj.get("result")) + .and_then(|v| v.as_object()) + .and_then(|obj| obj.get("content")) + .and_then(|v| v.as_array()) + .and_then(|content| { + content.iter().find_map(|item| { + item.as_object() + .filter(|o| o.get("type").and_then(|v| v.as_str()) == Some("text")) + .and_then(|o| o.get("text")) + .and_then(|v| v.as_str()) + .filter(|text| !text.trim().is_empty()) + .map(str::to_string) + }) + }) +} + +/// Read image-generation error status from JSON-RPC payload: +/// `result.isError` (defaults to `false` when missing). +pub fn is_image_generation_error(value: &Value) -> bool { + value + .as_object() + .and_then(|obj| obj.get("result")) + .and_then(|v| v.as_object()) + .and_then(|result_obj| result_obj.get("isError")) + .and_then(|v| v.as_bool()) + .unwrap_or(false) +} + /// Transforms MCP CallToolResult to OpenAI Responses API output items. pub struct ResponseTransformer; impl ResponseTransformer { + fn is_image_payload_candidate(obj: &serde_json::Map) -> bool { + obj.get("result").and_then(|v| v.as_str()).is_some() + } + + /// Extract image payload from MCP JSON-RPC `tools/call` response. + /// + /// Success example (`isError=false`), where `text` contains JSON with base64 result: + /// `{"result":{"content":[{"type":"text","text":"{\"result\":\"\"}"}],"isError":false}}` + /// + /// Error example (`isError=true`), where `text` is plain error text: + /// `{"result":{"content":[{"type":"text","text":"Error executing tool generate_image: ..."}],"isError":true}}` + /// + /// Any other shape is treated as unexpected and handled by the caller's + /// single fallback path. + fn image_payload_from_wrapped_content(result: &Value) -> Option { + // Parse the JSON-RPC wrapper object under top-level `result`. + let result_obj = result.get("result").and_then(|v| v.as_object())?; + + // `isError` determines whether we extract raw error text or image payload JSON. + let is_error = is_image_generation_error(result); + + // Error responses should preserve the raw text error payload. + if is_error { + return extract_image_generation_fallback_text(result).map(Value::String); + } + + let content = result_obj.get("content").and_then(|v| v.as_array())?; + for item in content { + let Some(obj) = item.as_object() else { + continue; + }; + // Only parse text content entries from MCP content blocks. + if obj.get("type").and_then(|v| v.as_str()) != Some("text") { + continue; + } + let Some(text) = obj.get("text").and_then(|v| v.as_str()) else { + continue; + }; + // Success payload is expected as JSON text in the content block. + if let Ok(parsed) = serde_json::from_str::(text) { + // Accept only image payload objects that include a string result. + if parsed + .as_object() + .is_some_and(Self::is_image_payload_candidate) + { + return Some(parsed); + } + } + } + + None + } + /// Transform an MCP result based on the configured response format. /// /// Returns a `ResponseOutputItem` from the protocols crate. pub fn transform( - result: &serde_json::Value, + result: &Value, format: &ResponseFormat, tool_call_id: &str, server_label: &str, @@ -49,13 +137,16 @@ impl ResponseTransformer { ResponseFormat::CodeInterpreterCall => { Self::to_code_interpreter_call(result, tool_call_id) } + ResponseFormat::ImageGenerationCall => { + Self::to_image_generation_call(result, tool_call_id) + } ResponseFormat::FileSearchCall => Self::to_file_search_call(result, tool_call_id), } } /// Transform to mcp_call output (passthrough). fn to_mcp_call( - result: &serde_json::Value, + result: &Value, tool_call_id: &str, server_label: &str, tool_name: &str, @@ -73,10 +164,10 @@ impl ResponseTransformer { } } - /// Flatten passthrough MCP results into the plain-string output shape used by OpenAI. - fn flatten_mcp_output(result: &serde_json::Value) -> String { + /// Flatten passthrough MCP results into plain text for OpenAI-compatible output. + fn flatten_mcp_output(result: &Value) -> String { match result { - serde_json::Value::String(text) => text.clone(), + Value::String(text) => text.clone(), _ => { let mut text_parts = Vec::new(); Self::collect_text_parts(result, &mut text_parts); @@ -89,14 +180,14 @@ impl ResponseTransformer { } } - fn collect_text_parts(value: &serde_json::Value, text_parts: &mut Vec) { + fn collect_text_parts(value: &Value, text_parts: &mut Vec) { match value { - serde_json::Value::Array(items) => { + Value::Array(items) => { for item in items { Self::collect_text_parts(item, text_parts); } } - serde_json::Value::Object(obj) => { + Value::Object(obj) => { if obj.get("type").and_then(|v| v.as_str()) == Some("text") { if let Some(text) = obj.get("text").and_then(|v| v.as_str()) { text_parts.push(text.to_string()); @@ -133,7 +224,7 @@ impl ResponseTransformer { } /// Transform MCP web search results to OpenAI web_search_call format. - fn to_web_search_call(result: &serde_json::Value, tool_call_id: &str) -> ResponseOutputItem { + fn to_web_search_call(result: &Value, tool_call_id: &str) -> ResponseOutputItem { let sources = Self::extract_web_sources(result); let queries = Self::extract_queries(result); @@ -149,10 +240,7 @@ impl ResponseTransformer { } /// Transform MCP code interpreter results to OpenAI code_interpreter_call format. - fn to_code_interpreter_call( - result: &serde_json::Value, - tool_call_id: &str, - ) -> ResponseOutputItem { + fn to_code_interpreter_call(result: &Value, tool_call_id: &str) -> ResponseOutputItem { let obj = result.as_object(); let container_id = obj @@ -177,8 +265,69 @@ impl ResponseTransformer { } } + /// Transform MCP image generation results to OpenAI image_generation_call format. + fn to_image_generation_call(result: &Value, tool_call_id: &str) -> ResponseOutputItem { + let payload = Self::image_payload_from_wrapped_content(result) + .unwrap_or_else(|| Value::String(Self::flatten_mcp_output(result))); + let parsed_payload = payload + .as_str() + .and_then(|s| serde_json::from_str::(s).ok()); + let obj = payload + .as_object() + .or_else(|| parsed_payload.as_ref().and_then(|v| v.as_object())); + + let status = ImageGenerationCallStatus::Completed; + let output_result = payload + .as_object() + .and_then(|obj| obj.get("result")) + .and_then(|v| v.as_str()) + .map(String::from) + .or_else(|| { + parsed_payload + .as_ref() + .and_then(|v| v.as_object()) + .and_then(|obj| obj.get("result")) + .and_then(|v| v.as_str()) + .map(String::from) + }) + .or_else(|| extract_image_generation_fallback_text(result)) + .or_else(|| payload.as_str().map(String::from)) + .or_else(|| result.as_str().map(String::from)) + .or_else(|| Some(payload.to_string())); + + ResponseOutputItem::ImageGenerationCall { + id: format!("ig_{tool_call_id}"), + status, + result: output_result, + revised_prompt: obj + .and_then(|o| o.get("revised_prompt")) + .and_then(|v| v.as_str()) + .map(String::from), + background: obj + .and_then(|o| o.get("background")) + .and_then(|v| v.as_str()) + .map(String::from), + output_format: obj + .and_then(|o| o.get("output_format")) + .and_then(|v| v.as_str()) + .map(String::from), + quality: obj + .and_then(|o| o.get("quality")) + .and_then(|v| v.as_str()) + .map(String::from), + size: obj + .and_then(|o| o.get("size")) + .and_then(|v| v.as_str()) + .map(String::from), + action: obj + .and_then(|o| o.get("action")) + .and_then(|v| v.as_str()) + .map(String::from), + } + } + /// Transform MCP file search results to OpenAI file_search_call format. - fn to_file_search_call(result: &serde_json::Value, tool_call_id: &str) -> ResponseOutputItem { + fn to_file_search_call(result: &Value, tool_call_id: &str) -> ResponseOutputItem { let obj = result.as_object(); let queries = Self::extract_queries(result); @@ -200,7 +349,7 @@ impl ResponseTransformer { } /// Extract web sources from MCP result. - fn extract_web_sources(result: &serde_json::Value) -> Vec { + fn extract_web_sources(result: &Value) -> Vec { let maybe_array = result.as_array().or_else(|| { result .as_object() @@ -214,7 +363,7 @@ impl ResponseTransformer { } /// Parse a single web source from JSON. - fn parse_web_source(item: &serde_json::Value) -> Option { + fn parse_web_source(item: &Value) -> Option { let obj = item.as_object()?; let url = obj.get("url").and_then(|v| v.as_str())?; Some(WebSearchSource { @@ -224,7 +373,7 @@ impl ResponseTransformer { } /// Extract queries from MCP result. - fn extract_queries(result: &serde_json::Value) -> Vec { + fn extract_queries(result: &Value) -> Vec { result .as_object() .and_then(|obj| obj.get("queries")) @@ -239,7 +388,7 @@ impl ResponseTransformer { } /// Extract code interpreter outputs from MCP result. - fn extract_code_outputs(result: &serde_json::Value) -> Vec { + fn extract_code_outputs(result: &Value) -> Vec { let mut outputs = Vec::new(); if let Some(obj) = result.as_object() { @@ -287,7 +436,7 @@ impl ResponseTransformer { } /// Extract file search results from MCP result. - fn extract_file_results(result: &serde_json::Value) -> Vec { + fn extract_file_results(result: &Value) -> Vec { result .as_object() .and_then(|obj| obj.get("results")) @@ -297,7 +446,7 @@ impl ResponseTransformer { } /// Parse a file search result from JSON. - fn parse_file_result(item: &serde_json::Value) -> Option { + fn parse_file_result(item: &Value) -> Option { let obj = item.as_object()?; let file_id = obj.get("file_id").and_then(|v| v.as_str())?.to_string(); let filename = obj.get("filename").and_then(|v| v.as_str())?.to_string(); @@ -316,7 +465,7 @@ impl ResponseTransformer { #[cfg(test)] mod tests { - use serde_json::json; + use serde_json::{json, to_value}; use super::*; @@ -638,4 +787,130 @@ mod tests { _ => panic!("Expected FileSearchCall"), } } + + #[test] + fn test_image_generation_transform_wrapped_content_extracts_metadata() { + let result = json!({ + "jsonrpc": "2.0", + "id": 3, + "result": { + "content": [ + { + "type": "text", + "text": "{\"result\":\"ZmFrZV9iYXNlNjQ=\",\"status\":\"completed\",\"action\":\"generate\",\"background\":\"opaque\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\",\"revised_prompt\":\"rp\"}" + } + ] + } + }); + + let transformed = ResponseTransformer::transform( + &result, + &ResponseFormat::ImageGenerationCall, + "req-1003", + "server", + "image_generation", + "{}", + ); + + match transformed { + ResponseOutputItem::ImageGenerationCall { + id, + status, + result, + action, + background, + output_format, + quality, + size, + revised_prompt, + } => { + assert_eq!(id, "ig_req-1003"); + assert_eq!(status, ImageGenerationCallStatus::Completed); + assert_eq!(result.as_deref(), Some("ZmFrZV9iYXNlNjQ=")); + assert_eq!(action.as_deref(), Some("generate")); + assert_eq!(background.as_deref(), Some("opaque")); + assert_eq!(output_format.as_deref(), Some("png")); + assert_eq!(quality.as_deref(), Some("high")); + assert_eq!(size.as_deref(), Some("1024x1024")); + assert_eq!(revised_prompt.as_deref(), Some("rp")); + } + _ => panic!("Expected ImageGenerationCall"), + } + } + + #[test] + fn test_image_generation_transform_output_shape_matches_dataplane() { + let result = json!({ + "jsonrpc": "2.0", + "id": 3, + "result": { + "content": [ + { + "type": "text", + "text": "{\"result\":\"ZmFrZV9iYXNlNjQ=\",\"action\":\"generate\",\"background\":\"opaque\",\"output_format\":\"png\",\"quality\":\"high\"}" + } + ], + "isError": false + } + }); + + let transformed = ResponseTransformer::transform( + &result, + &ResponseFormat::ImageGenerationCall, + "req-shape", + "server", + "image_generation", + "{}", + ); + + let item = to_value(&transformed).expect("image_generation_call should serialize"); + assert_eq!(item["type"], "image_generation_call"); + assert_eq!(item["id"], "ig_req-shape"); + assert_eq!(item["status"], "completed"); + assert_eq!(item["result"], "ZmFrZV9iYXNlNjQ="); + assert_eq!(item["action"], "generate"); + assert_eq!(item["background"], "opaque"); + assert_eq!(item["output_format"], "png"); + assert_eq!(item["quality"], "high"); + } + + #[test] + fn test_image_generation_transform_wrapped_content_skips_non_image_json_text() { + let result = json!({ + "jsonrpc": "2.0", + "id": 3, + "result": { + "content": [ + { + "type": "text", + "text": "{\"foo\":\"bar\",\"trace_id\":\"abc\"}" + }, + { + "type": "text", + "text": "{\"result\":\"ZmFrZV9iYXNlNjQ=\",\"status\":\"completed\"}" + } + ] + } + }); + + let transformed = ResponseTransformer::transform( + &result, + &ResponseFormat::ImageGenerationCall, + "req-1004", + "server", + "image_generation", + "{}", + ); + + match transformed { + ResponseOutputItem::ImageGenerationCall { + id, status, result, .. + } => { + assert_eq!(id, "ig_req-1004"); + assert_eq!(status, ImageGenerationCallStatus::Completed); + assert_eq!(result.as_deref(), Some("ZmFrZV9iYXNlNjQ=")); + } + _ => panic!("Expected ImageGenerationCall"), + } + } } diff --git a/crates/mcp/src/transform/types.rs b/crates/mcp/src/transform/types.rs index f508dc9d4..b64725dab 100644 --- a/crates/mcp/src/transform/types.rs +++ b/crates/mcp/src/transform/types.rs @@ -15,6 +15,8 @@ pub enum ResponseFormat { WebSearchCall, /// Transform to OpenAI code_interpreter_call format CodeInterpreterCall, + /// Transform to OpenAI image_generation_call format + ImageGenerationCall, /// Transform to OpenAI file_search_call format FileSearchCall, } @@ -25,6 +27,7 @@ impl From for ResponseFormat { ResponseFormatConfig::Passthrough => ResponseFormat::Passthrough, ResponseFormatConfig::WebSearchCall => ResponseFormat::WebSearchCall, ResponseFormatConfig::CodeInterpreterCall => ResponseFormat::CodeInterpreterCall, + ResponseFormatConfig::ImageGenerationCall => ResponseFormat::ImageGenerationCall, ResponseFormatConfig::FileSearchCall => ResponseFormat::FileSearchCall, } } @@ -43,6 +46,10 @@ mod tests { ResponseFormat::CodeInterpreterCall, "\"code_interpreter_call\"", ), + ( + ResponseFormat::ImageGenerationCall, + "\"image_generation_call\"", + ), (ResponseFormat::FileSearchCall, "\"file_search_call\""), ]; diff --git a/crates/protocols/src/event_types.rs b/crates/protocols/src/event_types.rs index fbf23e8c8..20f7db79f 100644 --- a/crates/protocols/src/event_types.rs +++ b/crates/protocols/src/event_types.rs @@ -263,6 +263,37 @@ impl fmt::Display for FileSearchCallEvent { } } +/// Image generation call events for streaming +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ImageGenerationCallEvent { + InProgress, + Generating, + PartialImage, + Completed, +} + +impl ImageGenerationCallEvent { + pub const IN_PROGRESS: &'static str = "response.image_generation_call.in_progress"; + pub const GENERATING: &'static str = "response.image_generation_call.generating"; + pub const PARTIAL_IMAGE: &'static str = "response.image_generation_call.partial_image"; + pub const COMPLETED: &'static str = "response.image_generation_call.completed"; + + pub const fn as_str(self) -> &'static str { + match self { + Self::InProgress => Self::IN_PROGRESS, + Self::Generating => Self::GENERATING, + Self::PartialImage => Self::PARTIAL_IMAGE, + Self::Completed => Self::COMPLETED, + } + } +} + +impl fmt::Display for ImageGenerationCallEvent { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + /// Item type discriminators used in output items #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum ItemType { @@ -274,6 +305,7 @@ pub enum ItemType { WebSearchCall, CodeInterpreterCall, FileSearchCall, + ImageGenerationCall, } impl ItemType { @@ -286,6 +318,7 @@ impl ItemType { pub const WEB_SEARCH_CALL: &'static str = "web_search_call"; pub const CODE_INTERPRETER_CALL: &'static str = "code_interpreter_call"; pub const FILE_SEARCH_CALL: &'static str = "file_search_call"; + pub const IMAGE_GENERATION_CALL: &'static str = "image_generation_call"; pub const fn as_str(self) -> &'static str { match self { @@ -297,6 +330,7 @@ impl ItemType { Self::WebSearchCall => Self::WEB_SEARCH_CALL, Self::CodeInterpreterCall => Self::CODE_INTERPRETER_CALL, Self::FileSearchCall => Self::FILE_SEARCH_CALL, + Self::ImageGenerationCall => Self::IMAGE_GENERATION_CALL, } } @@ -309,7 +343,10 @@ impl ItemType { pub const fn is_builtin_tool_call(self) -> bool { matches!( self, - Self::WebSearchCall | Self::CodeInterpreterCall | Self::FileSearchCall + Self::WebSearchCall + | Self::CodeInterpreterCall + | Self::FileSearchCall + | Self::ImageGenerationCall ) } } diff --git a/crates/protocols/src/responses.rs b/crates/protocols/src/responses.rs index ab2d3ddc6..22a9f5709 100644 --- a/crates/protocols/src/responses.rs +++ b/crates/protocols/src/responses.rs @@ -36,6 +36,10 @@ pub enum ResponseTool { #[serde(rename = "code_interpreter")] CodeInterpreter(CodeInterpreterTool), + /// Built-in tool. + #[serde(rename = "image_generation")] + ImageGeneration(ImageGenerationTool), + /// MCP server tool. #[serde(rename = "mcp")] Mcp(McpTool), @@ -80,6 +84,24 @@ pub struct CodeInterpreterTool { pub container: Option, } +/// Built-in image generation tool. +/// +/// Known fields are typed for validation; unknown fields are preserved for +/// forward compatibility via `extra`. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize, Default, schemars::JsonSchema)] +pub struct ImageGenerationTool { + pub size: Option, + pub quality: Option, + pub background: Option, + pub output_format: Option, + pub output_compression: Option, + pub moderation: Option, + pub model: Option, + #[serde(flatten)] + pub extra: HashMap, +} + /// `require_approval` values. #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, schemars::JsonSchema)] #[serde(rename_all = "snake_case")] @@ -314,6 +336,25 @@ pub enum ResponseOutputItem { queries: Vec, results: Option>, }, + #[serde(rename = "image_generation_call")] + ImageGenerationCall { + id: String, + status: ImageGenerationCallStatus, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + revised_prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + background: Option, + #[serde(skip_serializing_if = "Option::is_none")] + output_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + quality: Option, + #[serde(skip_serializing_if = "Option::is_none")] + size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + action: Option, + }, } // ============================================================================ @@ -389,6 +430,16 @@ pub enum FileSearchCallStatus { Failed, } +/// Status for image generation tool calls. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum ImageGenerationCallStatus { + InProgress, + Generating, + Completed, + Failed, +} + /// A result from file search. #[serde_with::skip_serializing_none] #[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)] @@ -1454,9 +1505,83 @@ impl ResponseReasoningContent { #[cfg(test)] mod tests { use serde_json::json; + use validator::Validate; use super::*; + #[test] + fn deserialize_image_generation_tool_with_options() { + let tool: ResponseTool = serde_json::from_value(json!({ + "type": "image_generation", + "size": "1024x1024", + "quality": "high", + "custom_flag": true + })) + .expect("image_generation tool should deserialize"); + + match tool { + ResponseTool::ImageGeneration(t) => { + assert_eq!(t.size.as_deref(), Some("1024x1024")); + assert_eq!(t.quality.as_deref(), Some("high")); + assert_eq!(t.extra.get("custom_flag"), Some(&json!(true))); + } + other => panic!("expected image_generation tool, got {other:?}"), + } + } + + #[test] + fn validate_request_with_image_generation_allowed_tool_reference() { + let req: ResponsesRequest = serde_json::from_value(json!({ + "model": "gpt-4o-mini", + "input": "draw a robot", + "tools": [ + { "type": "image_generation", "size": "1024x1024" } + ], + "tool_choice": { + "type": "allowed_tools", + "mode": "auto", + "tools": [ + { "type": "image_generation" } + ] + } + })) + .expect("responses request should deserialize"); + + req.validate() + .expect("request with image_generation tool_choice should validate"); + + assert!(matches!( + req.tool_choice, + Some(ToolChoice::AllowedTools { .. }) + )); + } + + #[test] + fn deserialize_image_generation_output_item_with_extra_fields() { + let item: ResponseOutputItem = serde_json::from_value(json!({ + "id": "ig_123", + "type": "image_generation_call", + "status": "completed", + "action": "generate", + "output_format": "png", + "size": "1024x1024", + "result": "ZmFrZV9pbWFnZQ==", + "revised_prompt": "a red panda" + })) + .expect("image_generation_call output item should deserialize"); + + match item { + ResponseOutputItem::ImageGenerationCall { + id, status, result, .. + } => { + assert_eq!(id, "ig_123"); + assert_eq!(status, ImageGenerationCallStatus::Completed); + assert_eq!(result.as_deref(), Some("ZmFrZV9pbWFnZQ==")); + } + other => panic!("expected image_generation_call output item, got {other:?}"), + } + } + #[test] fn test_responses_request_omitted_top_p_deserializes_to_none() { let request: ResponsesRequest = serde_json::from_value(json!({ diff --git a/model_gateway/src/routers/grpc/common/responses/streaming.rs b/model_gateway/src/routers/grpc/common/responses/streaming.rs index b5300b2e1..ee1d572cb 100644 --- a/model_gateway/src/routers/grpc/common/responses/streaming.rs +++ b/model_gateway/src/routers/grpc/common/responses/streaming.rs @@ -8,7 +8,8 @@ use openai_protocol::{ common::{Usage, UsageInfo}, event_types::{ CodeInterpreterCallEvent, ContentPartEvent, FileSearchCallEvent, FunctionCallEvent, - McpEvent, OutputItemEvent, OutputTextEvent, ResponseEvent, WebSearchCallEvent, + ImageGenerationCallEvent, McpEvent, OutputItemEvent, OutputTextEvent, ResponseEvent, + WebSearchCallEvent, }, responses::{ ResponseOutputItem, ResponseStatus, ResponsesRequest, ResponsesResponse, ResponsesUsage, @@ -32,6 +33,7 @@ pub(crate) enum OutputItemType { WebSearchCall, CodeInterpreterCall, FileSearchCall, + ImageGenerationCall, } /// Status of an output item @@ -120,7 +122,8 @@ impl ResponseStreamEventEmitter { /// /// After MCP tools are executed, this updates the stored output items /// to include the output field from the tool results. - /// Supports mcp_call, web_search_call, code_interpreter_call, and file_search_call item types. + /// Supports mcp_call, web_search_call, code_interpreter_call, file_search_call, + /// and image_generation_call item types. pub(crate) fn update_mcp_call_outputs(&mut self, tool_results: &[ToolResult]) { for tool_result in tool_results { // Find the output item with matching call_id @@ -134,6 +137,7 @@ impl ResponseStreamEventEmitter { | Some("web_search_call") | Some("code_interpreter_call") | Some("file_search_call") + | Some("image_generation_call") ); if is_tool_call && item_data.get("call_id").and_then(|c| c.as_str()) @@ -475,6 +479,7 @@ impl ResponseStreamEventEmitter { ResponseFormat::WebSearchCall => WebSearchCallEvent::IN_PROGRESS, ResponseFormat::CodeInterpreterCall => CodeInterpreterCallEvent::IN_PROGRESS, ResponseFormat::FileSearchCall => FileSearchCallEvent::IN_PROGRESS, + ResponseFormat::ImageGenerationCall => ImageGenerationCallEvent::IN_PROGRESS, ResponseFormat::Passthrough => McpEvent::CALL_IN_PROGRESS, }; self.emit_tool_event(event_type, output_index, item_id) @@ -491,6 +496,7 @@ impl ResponseStreamEventEmitter { ResponseFormat::WebSearchCall => WebSearchCallEvent::SEARCHING, ResponseFormat::CodeInterpreterCall => CodeInterpreterCallEvent::INTERPRETING, ResponseFormat::FileSearchCall => FileSearchCallEvent::SEARCHING, + ResponseFormat::ImageGenerationCall => ImageGenerationCallEvent::GENERATING, ResponseFormat::Passthrough => return None, }; Some(self.emit_tool_event(event_type, output_index, item_id)) @@ -507,6 +513,7 @@ impl ResponseStreamEventEmitter { ResponseFormat::WebSearchCall => WebSearchCallEvent::COMPLETED, ResponseFormat::CodeInterpreterCall => CodeInterpreterCallEvent::COMPLETED, ResponseFormat::FileSearchCall => FileSearchCallEvent::COMPLETED, + ResponseFormat::ImageGenerationCall => ImageGenerationCallEvent::COMPLETED, ResponseFormat::Passthrough => McpEvent::CALL_COMPLETED, }; self.emit_tool_event(event_type, output_index, item_id) @@ -522,6 +529,7 @@ impl ResponseStreamEventEmitter { Some(ResponseFormat::WebSearchCall) => "web_search_call", Some(ResponseFormat::CodeInterpreterCall) => "code_interpreter_call", Some(ResponseFormat::FileSearchCall) => "file_search_call", + Some(ResponseFormat::ImageGenerationCall) => "image_generation_call", Some(ResponseFormat::Passthrough) => "mcp_call", None => "function_call", } @@ -533,6 +541,7 @@ impl ResponseStreamEventEmitter { Some(ResponseFormat::WebSearchCall) => OutputItemType::WebSearchCall, Some(ResponseFormat::CodeInterpreterCall) => OutputItemType::CodeInterpreterCall, Some(ResponseFormat::FileSearchCall) => OutputItemType::FileSearchCall, + Some(ResponseFormat::ImageGenerationCall) => OutputItemType::ImageGenerationCall, Some(ResponseFormat::Passthrough) => OutputItemType::McpCall, None => OutputItemType::FunctionCall, } @@ -626,6 +635,7 @@ impl ResponseStreamEventEmitter { OutputItemType::WebSearchCall => "ws", OutputItemType::CodeInterpreterCall => "ci", OutputItemType::FileSearchCall => "fs", + OutputItemType::ImageGenerationCall => "ig", }; let id = Self::generate_item_id(id_prefix); diff --git a/model_gateway/src/routers/grpc/common/responses/utils.rs b/model_gateway/src/routers/grpc/common/responses/utils.rs index f00f37d0a..0eee7a9b5 100644 --- a/model_gateway/src/routers/grpc/common/responses/utils.rs +++ b/model_gateway/src/routers/grpc/common/responses/utils.rs @@ -44,7 +44,9 @@ pub(crate) async fn ensure_mcp_connection( t.iter().any(|tool| { matches!( tool, - ResponseTool::WebSearchPreview(_) | ResponseTool::CodeInterpreter(_) + ResponseTool::WebSearchPreview(_) + | ResponseTool::CodeInterpreter(_) + | ResponseTool::ImageGeneration(_) ) }) }) diff --git a/model_gateway/src/routers/grpc/harmony/builder.rs b/model_gateway/src/routers/grpc/harmony/builder.rs index 14d874b4a..e5be29e66 100644 --- a/model_gateway/src/routers/grpc/harmony/builder.rs +++ b/model_gateway/src/routers/grpc/harmony/builder.rs @@ -60,7 +60,12 @@ pub(crate) fn convert_harmony_logprobs(proto_logprobs: &ProtoOutputLogProbs) -> } /// Built-in tools that are added to the system message -const BUILTIN_TOOLS: &[&str] = &["web_search_preview", "code_interpreter", "container"]; +const BUILTIN_TOOLS: &[&str] = &[ + "web_search_preview", + "code_interpreter", + "image_generation", + "container", +]; /// Trait for tool-like objects that can be converted to Harmony ToolDescription trait ToolLike { @@ -80,7 +85,7 @@ impl ToolLike for Tool { fn is_builtin(&self) -> bool { matches!( self.tool_type.as_str(), - "web_search_preview" | "code_interpreter" | "container" + "web_search_preview" | "code_interpreter" | "image_generation" | "container" ) } @@ -102,7 +107,9 @@ impl ToolLike for ResponseTool { fn is_builtin(&self) -> bool { matches!( self, - ResponseTool::WebSearchPreview(_) | ResponseTool::CodeInterpreter(_) + ResponseTool::WebSearchPreview(_) + | ResponseTool::CodeInterpreter(_) + | ResponseTool::ImageGeneration(_) ) } @@ -425,6 +432,7 @@ impl HarmonyBuilder { ResponseTool::Function(_) => "function", ResponseTool::WebSearchPreview(_) => "web_search_preview", ResponseTool::CodeInterpreter(_) => "code_interpreter", + ResponseTool::ImageGeneration(_) => "image_generation", ResponseTool::Mcp(_) => "mcp", }) .collect() diff --git a/model_gateway/src/routers/grpc/regular/responses/common.rs b/model_gateway/src/routers/grpc/regular/responses/common.rs index 9ac942fb1..7f1074f51 100644 --- a/model_gateway/src/routers/grpc/regular/responses/common.rs +++ b/model_gateway/src/routers/grpc/regular/responses/common.rs @@ -15,13 +15,17 @@ use openai_protocol::{ ResponsesRequest, }, }; +use serde_json::Value; use smg_data_connector::{ self as data_connector, ConversationId, ResponseId, ResponseStorageError, }; -use smg_mcp::McpToolSession; +use smg_mcp::{McpToolSession, ResponseFormat}; use tracing::{debug, warn}; -use crate::routers::{error, grpc::common::responses::ResponsesContext}; +use crate::routers::{ + error, grpc::common::responses::ResponsesContext, + tool_output_context::compact_tool_output_for_model_context, +}; // ============================================================================ // Tool Loop State @@ -61,9 +65,14 @@ impl ToolLoopState { tool_name: String, args_json_str: String, output_str: String, + response_format: ResponseFormat, output_item: ResponseOutputItem, - _success: bool, ) { + let output_value = + serde_json::from_str::(&output_str).unwrap_or(Value::String(output_str)); + let model_context_output = + compact_tool_output_for_model_context(&response_format, &output_value); + // Add function_tool_call item with both arguments and output let id = call_id.clone(); self.conversation_history @@ -72,7 +81,7 @@ impl ToolLoopState { call_id, name: tool_name, arguments: args_json_str, - output: Some(output_str), + output: Some(model_context_output), status: Some("completed".to_string()), }); diff --git a/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs b/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs index d409b878a..ff2f42f9f 100644 --- a/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs +++ b/model_gateway/src/routers/grpc/regular/responses/non_streaming.rs @@ -375,8 +375,8 @@ pub(super) async fn execute_tool_loop( result.tool_name, result.arguments_str, output_str, + result.response_format, output_item, - !result.is_error, ); // Increment total calls counter diff --git a/model_gateway/src/routers/grpc/regular/responses/streaming.rs b/model_gateway/src/routers/grpc/regular/responses/streaming.rs index e544e5806..a2ff51db6 100644 --- a/model_gateway/src/routers/grpc/regular/responses/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/responses/streaming.rs @@ -801,8 +801,8 @@ async fn execute_tool_loop_streaming_internal( tool_output.tool_name, tool_output.arguments_str, output_str, + tool_output.response_format, output_item, - success, ); } diff --git a/model_gateway/src/routers/mcp_utils.rs b/model_gateway/src/routers/mcp_utils.rs index d40d5d687..2ec75d0b9 100644 --- a/model_gateway/src/routers/mcp_utils.rs +++ b/model_gateway/src/routers/mcp_utils.rs @@ -114,7 +114,7 @@ pub async fn connect_mcp_servers( /// Routing information for a built-in tool type. /// -/// When a built-in tool type (web_search_preview, code_interpreter, file_search) +/// When a built-in tool type (web_search_preview, code_interpreter, image_generation, file_search) /// is configured to route to an MCP server, this struct holds the routing details. #[derive(Debug, Clone)] pub struct BuiltinToolRouting { @@ -130,7 +130,7 @@ pub struct BuiltinToolRouting { /// Collect routing information for built-in tools in a request. /// -/// Scans request tools for built-in types (web_search_preview, code_interpreter, file_search) +/// Scans request tools for built-in types (web_search_preview, code_interpreter, image_generation, file_search) /// and looks up configured MCP servers to handle them. /// /// # Arguments @@ -154,6 +154,7 @@ pub fn collect_builtin_routing( let builtin_type = match tool { ResponseTool::WebSearchPreview(_) => BuiltinToolType::WebSearchPreview, ResponseTool::CodeInterpreter(_) => BuiltinToolType::CodeInterpreter, + ResponseTool::ImageGeneration(_) => BuiltinToolType::ImageGeneration, _ => continue, }; @@ -194,6 +195,7 @@ pub fn extract_builtin_types(tools: &[ResponseTool]) -> Vec { .filter_map(|t| match t { ResponseTool::WebSearchPreview(_) => Some(BuiltinToolType::WebSearchPreview), ResponseTool::CodeInterpreter(_) => Some(BuiltinToolType::CodeInterpreter), + ResponseTool::ImageGeneration(_) => Some(BuiltinToolType::ImageGeneration), _ => None, }) .collect() @@ -277,7 +279,8 @@ mod tests { use openai_protocol::{ common::Function, responses::{ - CodeInterpreterTool, FunctionTool, McpTool, ResponseTool, WebSearchPreviewTool, + CodeInterpreterTool, FunctionTool, ImageGenerationTool, McpTool, ResponseTool, + WebSearchPreviewTool, }, }; use serde_json::json; @@ -463,6 +466,25 @@ mod tests { builtin_type: Some(BuiltinToolType::CodeInterpreter), builtin_tool_name: Some("run_code".to_string()), }, + McpServerConfig { + name: "image-server".to_string(), + transport: McpTransport::Streamable { + url: "http://localhost:9997/image".to_string(), + token: None, + headers: HashMap::new(), + }, + proxy: None, + required: false, + tools: Some(HashMap::from([( + "generate_image".to_string(), + ToolConfig { + response_format: ResponseFormatConfig::ImageGenerationCall, + ..Default::default() + }, + )])), + builtin_type: Some(BuiltinToolType::ImageGeneration), + builtin_tool_name: Some("generate_image".to_string()), + }, ], pool: Default::default(), proxy: None, @@ -476,11 +498,12 @@ mod tests { let tools = vec![ ResponseTool::WebSearchPreview(WebSearchPreviewTool::default()), ResponseTool::CodeInterpreter(CodeInterpreterTool::default()), + ResponseTool::ImageGeneration(ImageGenerationTool::default()), ]; let routing = collect_builtin_routing(&orchestrator, Some(&tools)); - assert_eq!(routing.len(), 2); + assert_eq!(routing.len(), 3); // Find web search routing let web_routing = routing @@ -502,6 +525,32 @@ mod tests { code_routing.response_format, ResponseFormat::CodeInterpreterCall ); + + // Find image generation routing + let image_routing = routing + .iter() + .find(|r| r.builtin_type == BuiltinToolType::ImageGeneration) + .expect("Should have image generation routing"); + assert_eq!(image_routing.server_name, "image-server"); + assert_eq!(image_routing.tool_name, "generate_image"); + assert_eq!( + image_routing.response_format, + ResponseFormat::ImageGenerationCall + ); + } + + #[test] + fn test_extract_builtin_types_includes_image_generation() { + let tools = vec![ + ResponseTool::WebSearchPreview(WebSearchPreviewTool::default()), + ResponseTool::CodeInterpreter(CodeInterpreterTool::default()), + ResponseTool::ImageGeneration(ImageGenerationTool::default()), + ]; + + let types = extract_builtin_types(&tools); + assert!(types.contains(&BuiltinToolType::WebSearchPreview)); + assert!(types.contains(&BuiltinToolType::CodeInterpreter)); + assert!(types.contains(&BuiltinToolType::ImageGeneration)); } // ========================================================================= diff --git a/model_gateway/src/routers/mod.rs b/model_gateway/src/routers/mod.rs index 21594aff3..d6fa648a6 100644 --- a/model_gateway/src/routers/mod.rs +++ b/model_gateway/src/routers/mod.rs @@ -41,6 +41,7 @@ pub mod persistence_utils; pub mod responses; pub mod router_manager; pub mod tokenize; +mod tool_output_context; pub mod worker_selection; pub use factory::RouterFactory; diff --git a/model_gateway/src/routers/openai/mcp/tool_loop.rs b/model_gateway/src/routers/openai/mcp/tool_loop.rs index 41da9a9d3..db96c2906 100644 --- a/model_gateway/src/routers/openai/mcp/tool_loop.rs +++ b/model_gateway/src/routers/openai/mcp/tool_loop.rs @@ -14,10 +14,10 @@ use axum::http::HeaderMap; use bytes::Bytes; use openai_protocol::{ event_types::{ - is_function_call_type, CodeInterpreterCallEvent, FileSearchCallEvent, ItemType, McpEvent, - OutputItemEvent, WebSearchCallEvent, + is_function_call_type, CodeInterpreterCallEvent, FileSearchCallEvent, + ImageGenerationCallEvent, ItemType, McpEvent, OutputItemEvent, WebSearchCallEvent, }, - responses::{generate_id, ResponseInput, ResponsesRequest}, + responses::{generate_id, ResponseInput, ResponseTool, ResponsesRequest}, }; use serde_json::{json, to_value, Value}; use smg_mcp::{ @@ -29,7 +29,10 @@ use tracing::{debug, info, warn}; use super::tool_handler::FunctionCallInProgress; use crate::{ observability::metrics::{metrics_labels, Metrics}, - routers::{error, header_utils::ApiProvider, mcp_utils::DEFAULT_MAX_ITERATIONS}, + routers::{ + error, header_utils::ApiProvider, mcp_utils::DEFAULT_MAX_ITERATIONS, + tool_output_context::compact_tool_output_for_model_context, + }, }; /// State for tracking multi-turn tool calling loop @@ -96,7 +99,7 @@ pub(crate) async fn execute_streaming_tool_calls( tx: &mpsc::UnboundedSender>, state: &mut ToolLoopState, sequence_number: &mut u64, - model_id: &str, + original_body: &ResponsesRequest, ) -> bool { for call in pending_calls { if call.name.is_empty() { @@ -121,7 +124,7 @@ pub(crate) async fn execute_streaming_tool_calls( let response_format = session.tool_response_format(&call.name); let server_label = session.resolve_tool_server_label(&call.name); - let arguments: Value = match serde_json::from_str(args_str) { + let mut arguments: Value = match serde_json::from_str(args_str) { Ok(v) => v, Err(e) => { let err_str = format!("Failed to parse tool arguments: {e}"); @@ -160,12 +163,12 @@ pub(crate) async fn execute_streaming_tool_calls( continue; } }; - + apply_request_tool_overrides(&response_format, original_body, &mut arguments); if !send_tool_call_intermediate_event(tx, &call, &response_format, sequence_number) { return false; } - debug!("Calling MCP tool '{}' with args: {}", call.name, args_str); + debug!("Calling MCP tool '{}' with args: {}", call.name, arguments); let tool_output = session .execute_tool(ToolExecutionInput { call_id: call.call_id.clone(), @@ -174,9 +177,13 @@ pub(crate) async fn execute_streaming_tool_calls( }) .await; - Metrics::record_mcp_tool_duration(model_id, &tool_output.tool_name, tool_output.duration); + Metrics::record_mcp_tool_duration( + &original_body.model, + &tool_output.tool_name, + tool_output.duration, + ); Metrics::record_mcp_tool_call( - model_id, + &original_body.model, &tool_output.tool_name, if tool_output.is_error { metrics_labels::RESULT_ERROR @@ -185,7 +192,8 @@ pub(crate) async fn execute_streaming_tool_calls( }, ); - let output_str = tool_output.output.to_string(); + let model_context_output = + compact_tool_output_for_model_context(&response_format, &tool_output.output); let mut mcp_call_item = to_value(tool_output.to_response_item()).unwrap_or_else(|e| { warn!(tool = %call.name, error = %e, "Failed to convert item to Value"); json!({}) @@ -210,8 +218,8 @@ pub(crate) async fn execute_streaming_tool_calls( state.record_call( call.call_id, call.name, - call.arguments_buffer, - output_str, + tool_output.arguments_str.clone(), + model_context_output, mcp_call_item, ); } @@ -253,6 +261,58 @@ pub(crate) fn prepare_mcp_tools_as_functions(payload: &mut Value, session: &McpT } } +/// Extract request-level builtin tool overrides to merge into tool-call arguments. +/// +/// Currently this is intentionally scoped to image generation only. +/// We can extend this to other builtin tools later if needed. +fn request_tool_overrides( + response_format: &ResponseFormat, + original_body: &ResponsesRequest, +) -> Option { + if !matches!(response_format, ResponseFormat::ImageGenerationCall) { + return None; + } + + // Read request-defined tools and find the image_generation config. + let tools = original_body.tools.as_ref()?; + + tools.iter().find_map(|tool| { + // Serialize image tool config into a JSON object for merge. + let mut serialized = match tool { + ResponseTool::ImageGeneration(image_tool) => match to_value(image_tool).ok()? { + Value::Object(obj) => obj, + _ => return None, + }, + _ => return None, + }; + // Drop nulls so absent fields do not overwrite generated call arguments. + serialized.retain(|_, v| !v.is_null()); + if serialized.is_empty() { + None + } else { + Some(Value::Object(serialized)) + } + }) +} + +fn apply_request_tool_overrides( + response_format: &ResponseFormat, + original_body: &ResponsesRequest, + arguments: &mut Value, +) { + if let (Some(overrides), Some(args_obj)) = ( + request_tool_overrides(response_format, original_body), + arguments.as_object_mut(), + ) { + let Some(override_obj) = overrides.as_object() else { + return; + }; + for (k, v) in override_obj { + args_obj.insert(k.clone(), v.clone()); + } + } +} + /// Build a resume payload with conversation history pub(crate) fn build_resume_payload( base_payload: &Value, @@ -399,11 +459,12 @@ fn send_tool_call_intermediate_event( response_format: &ResponseFormat, sequence_number: &mut u64, ) -> bool { - // Determine event type and ID prefix based on response format + // Determine event type based on response format let event_type = match response_format { ResponseFormat::WebSearchCall => WebSearchCallEvent::SEARCHING, ResponseFormat::CodeInterpreterCall => CodeInterpreterCallEvent::INTERPRETING, ResponseFormat::FileSearchCall => FileSearchCallEvent::SEARCHING, + ResponseFormat::ImageGenerationCall => ImageGenerationCallEvent::GENERATING, ResponseFormat::Passthrough => return true, // mcp_call has no intermediate event }; @@ -446,6 +507,7 @@ fn send_tool_call_completion_events( ItemType::WEB_SEARCH_CALL => WebSearchCallEvent::COMPLETED, ItemType::CODE_INTERPRETER_CALL => CodeInterpreterCallEvent::COMPLETED, ItemType::FILE_SEARCH_CALL => FileSearchCallEvent::COMPLETED, + ItemType::IMAGE_GENERATION_CALL => ImageGenerationCallEvent::COMPLETED, _ => McpEvent::CALL_COMPLETED, // Default to mcp_call for mcp_call and unknown types }; @@ -491,6 +553,7 @@ fn stable_streaming_tool_item_id( ResponseFormat::WebSearchCall => normalize_tool_item_id_with_prefix(source_id, "ws_"), ResponseFormat::CodeInterpreterCall => normalize_tool_item_id_with_prefix(source_id, "ci_"), ResponseFormat::FileSearchCall => normalize_tool_item_id_with_prefix(source_id, "fs_"), + ResponseFormat::ImageGenerationCall => normalize_tool_item_id_with_prefix(source_id, "ig_"), } } @@ -511,7 +574,8 @@ fn non_streaming_tool_item_id_source(item_id: &str, response_format: &ResponseFo ResponseFormat::Passthrough => item_id.to_string(), ResponseFormat::WebSearchCall | ResponseFormat::CodeInterpreterCall - | ResponseFormat::FileSearchCall => item_id + | ResponseFormat::FileSearchCall + | ResponseFormat::ImageGenerationCall => item_id .strip_prefix("fc_") .or_else(|| item_id.strip_prefix("call_")) .unwrap_or(item_id) @@ -622,6 +686,7 @@ pub(crate) async fn execute_tool_loop( for call in function_calls { state.total_calls += 1; + let response_format = session.tool_response_format(&call.name); if state.total_calls > effective_limit { warn!( @@ -636,12 +701,11 @@ pub(crate) async fn execute_tool_loop( original_body, ); } - let arguments: Value = match serde_json::from_str(&call.arguments) { + let mut arguments: Value = match serde_json::from_str(&call.arguments) { Ok(v) => v, Err(e) => { warn!(tool = %call.name, error = %e, "Failed to parse tool arguments as JSON"); let error_output = format!("Invalid tool arguments: {e}"); - let response_format = session.tool_response_format(&call.name); let server_label = session.resolve_tool_server_label(&call.name); let tool_item_id = non_streaming_tool_item_id_source(&call.item_id, &response_format); @@ -671,11 +735,8 @@ pub(crate) async fn execute_tool_loop( continue; } }; - - debug!( - "Calling MCP tool '{}' with args: {}", - call.name, call.arguments - ); + apply_request_tool_overrides(&response_format, original_body, &mut arguments); + debug!("Calling MCP tool '{}' with args: {}", call.name, arguments); let tool_output = session .execute_tool(ToolExecutionInput { call_id: call.call_id.clone(), @@ -699,8 +760,9 @@ pub(crate) async fn execute_tool_loop( }, ); - let output_str = tool_output.output.to_string(); let response_format = session.tool_response_format(&call.name); + let model_context_output = + compact_tool_output_for_model_context(&response_format, &tool_output.output); let server_label = session.resolve_tool_server_label(&call.name); let tool_item_id = non_streaming_tool_item_id_source(&call.item_id, &response_format); let transformed_item = build_transformed_mcp_call_item( @@ -709,14 +771,14 @@ pub(crate) async fn execute_tool_loop( &tool_item_id, &server_label, &call.name, - &call.arguments, + &tool_output.arguments_str, ); state.record_call( call.call_id, call.name, - call.arguments, - output_str, + tool_output.arguments_str.clone(), + model_context_output, transformed_item, ); } diff --git a/model_gateway/src/routers/openai/responses/streaming.rs b/model_gateway/src/routers/openai/responses/streaming.rs index 47d0f59c9..7142f9202 100644 --- a/model_gateway/src/routers/openai/responses/streaming.rs +++ b/model_gateway/src/routers/openai/responses/streaming.rs @@ -19,7 +19,8 @@ use futures_util::StreamExt; use openai_protocol::{ event_types::{ is_function_call_type, is_response_event, CodeInterpreterCallEvent, FileSearchCallEvent, - FunctionCallEvent, ItemType, McpEvent, OutputItemEvent, ResponseEvent, WebSearchCallEvent, + FunctionCallEvent, ImageGenerationCallEvent, ItemType, McpEvent, OutputItemEvent, + ResponseEvent, WebSearchCallEvent, }, responses::{ResponseTool, ResponsesRequest}, }; @@ -147,6 +148,9 @@ pub(super) fn apply_event_transformations_inplace( // Determine item type and ID prefix based on response_format let (new_type, id_prefix) = match response_format { ResponseFormat::WebSearchCall => (ItemType::WEB_SEARCH_CALL, "ws_"), + ResponseFormat::ImageGenerationCall => { + (ItemType::IMAGE_GENERATION_CALL, "ig_") + } _ => (ItemType::MCP_CALL, "mcp_"), }; @@ -429,6 +433,7 @@ fn maybe_inject_tool_in_progress( ItemType::WEB_SEARCH_CALL => WebSearchCallEvent::IN_PROGRESS, ItemType::CODE_INTERPRETER_CALL => CodeInterpreterCallEvent::IN_PROGRESS, ItemType::FILE_SEARCH_CALL => FileSearchCallEvent::IN_PROGRESS, + ItemType::IMAGE_GENERATION_CALL => ImageGenerationCallEvent::IN_PROGRESS, _ => return true, // Not a tool call item, nothing to inject }; @@ -967,7 +972,7 @@ pub(super) fn handle_streaming_with_tool_interception( &tx, &mut state, &mut sequence_number, - &original_request.model, + &original_request, ) .await { diff --git a/model_gateway/src/routers/openai/responses/utils.rs b/model_gateway/src/routers/openai/responses/utils.rs index 7c6e374bd..e54aa2b9f 100644 --- a/model_gateway/src/routers/openai/responses/utils.rs +++ b/model_gateway/src/routers/openai/responses/utils.rs @@ -230,6 +230,7 @@ pub(super) fn response_tool_to_value(tool: &ResponseTool) -> Option { } ResponseTool::WebSearchPreview(_) => serde_json::to_value(tool).ok(), ResponseTool::CodeInterpreter(_) => serde_json::to_value(tool).ok(), + ResponseTool::ImageGeneration(_) => serde_json::to_value(tool).ok(), ResponseTool::Function(_) => None, } } diff --git a/model_gateway/src/routers/tool_output_context.rs b/model_gateway/src/routers/tool_output_context.rs new file mode 100644 index 000000000..2f01bde00 --- /dev/null +++ b/model_gateway/src/routers/tool_output_context.rs @@ -0,0 +1,36 @@ +use serde_json::{json, Value}; +use smg_mcp::{extract_image_generation_fallback_text, is_image_generation_error, ResponseFormat}; + +/// Build tool output text for model context. +/// +/// This is format-driven and intended to support per-tool compaction policies. +/// Currently only `ResponseFormat::ImageGenerationCall` is compacted into a +/// minimal fixed summary (no payload/status) to avoid feeding large binary +/// image data back into the next model turn. Other formats are currently no-op +/// and return `output.to_string()` unchanged. +pub fn compact_tool_output_for_model_context( + response_format: &ResponseFormat, + output: &Value, +) -> String { + match response_format { + ResponseFormat::ImageGenerationCall => { + let is_error = is_image_generation_error(output); + let note = if is_error { + extract_image_generation_fallback_text(output).unwrap_or_default() + } else { + "Successfully generated the image".to_string() + }; + let summary = json!({ + "tool": "generate_image", + "status": if is_error { "failed" } else { "completed" }, + "note": note + }); + summary.to_string() + } + // No-op for other tools for now: preserve raw string outputs as-is. + _ => match output { + Value::String(text) => text.clone(), + _ => output.to_string(), + }, + } +}