From a801a609bb671bcfcd07bea0c24665a946a160d6 Mon Sep 17 00:00:00 2001
From: wgqqqqq
Date: Mon, 9 Mar 2026 10:45:28 +0800
Subject: [PATCH 1/4] feat: add OpenAI Responses API support

---
 .../core/src/agentic/image_analysis/types.rs  |   2 +-
 .../ai/ai_stream_handlers/src/lib.rs          |   1 +
 .../src/stream_handler/mod.rs                 |   4 +-
 .../src/stream_handler/responses.rs           | 266 ++++++++++++++++++
 .../ai/ai_stream_handlers/src/types/mod.rs    |   3 +-
 .../ai_stream_handlers/src/types/responses.rs | 172 +++++++++++
 .../core/src/infrastructure/ai/client.rs      | 183 +++++++++++-
 .../ai/providers/openai/message_converter.rs  | 214 ++++++++++++++
 src/crates/core/src/util/types/config.rs      |  38 +++
 .../components/steps/ModelConfigStep.tsx      |   9 +-
 .../onboarding/store/onboardingStore.ts       |   2 +-
 .../config/components/AIModelConfig.tsx       |  22 +-
 .../config/schemas/ai-models.json             |   5 +
 .../config/services/modelConfigs.ts           |   2 +
 .../src/locales/en-US/settings/ai-model.json  |   1 +
 .../src/locales/zh-CN/settings/ai-model.json  |   5 +-
 src/web-ui/src/shared/types/chat.ts           |   2 +-
 17 files changed, 914 insertions(+), 17 deletions(-)
 create mode 100644 src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/responses.rs
 create mode 100644 src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/responses.rs

diff --git a/src/crates/core/src/agentic/image_analysis/types.rs b/src/crates/core/src/agentic/image_analysis/types.rs
index 2037495d..f1d520c4 100644
--- a/src/crates/core/src/agentic/image_analysis/types.rs
+++ b/src/crates/core/src/agentic/image_analysis/types.rs
@@ -117,7 +117,7 @@ impl ImageLimits {
     /// Get limits based on model provider
     pub fn for_provider(provider: &str) -> Self {
         match provider.to_lowercase().as_str() {
-            "openai" => Self {
+            "openai" | "response" | "responses" => Self {
                 max_size: 20 * 1024 * 1024, // 20MB
                 max_width: 2048,
                 max_height: 2048,
diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs
index aa8ed3b7..2aed04b0 100644
--- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs
+++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs
@@ -3,4 +3,5 @@ mod types;
 
 pub use stream_handler::handle_anthropic_stream;
 pub use stream_handler::handle_openai_stream;
+pub use stream_handler::handle_responses_stream;
 pub use types::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall};
diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs
index a3f2f220..b117f3c7 100644
--- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs
+++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs
@@ -1,5 +1,7 @@
 mod openai;
 mod anthropic;
+mod responses;
 
 pub use openai::handle_openai_stream;
-pub use anthropic::handle_anthropic_stream;
\ No newline at end of file
+pub use anthropic::handle_anthropic_stream;
+pub use responses::handle_responses_stream;
diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/responses.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/responses.rs
new file mode 100644
index 00000000..646968c0
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/responses.rs
@@ -0,0 +1,266 @@
+use crate::types::responses::{
+    parse_responses_output_item, ResponsesCompleted,
+    ResponsesDone, ResponsesStreamEvent,
+};
+use crate::types::unified::UnifiedResponse;
+use anyhow::{anyhow, Result};
+use eventsource_stream::Eventsource;
+use futures::StreamExt;
+use log::{error, trace};
+use reqwest::Response;
+use serde_json::Value;
+use std::time::Duration;
+use tokio::sync::mpsc;
+use tokio::time::timeout;
+
+fn extract_api_error_message(event_json: &Value) -> Option<String> {
+    let response = event_json.get("response")?;
+    let error = response.get("error")?;
+
+    if error.is_null() {
+        return None;
+    }
+
+    if let Some(message) = error.get("message").and_then(Value::as_str) {
+        return Some(message.to_string());
+    }
+    if let Some(message) = error.as_str() {
+        return Some(message.to_string());
+    }
+
+    Some("An error occurred during responses streaming".to_string())
+}
+
+pub async fn handle_responses_stream(
+    response: Response,
+    tx_event: mpsc::UnboundedSender<Result<UnifiedResponse>>,
+    tx_raw_sse: Option<mpsc::UnboundedSender<String>>,
+) {
+    let mut stream = response.bytes_stream().eventsource();
+    let idle_timeout = Duration::from_secs(600);
+    let received_completion = false;
+    let mut received_text_delta = false;
+
+    loop {
+        let sse_event = timeout(idle_timeout, stream.next()).await;
+        let sse = match sse_event {
+            Ok(Some(Ok(sse))) => sse,
+            Ok(None) => {
+                if received_completion {
+                    return;
+                }
+                let error_msg = "Responses SSE stream closed before response completed";
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            Ok(Some(Err(e))) => {
+                let error_msg = format!("Responses SSE stream error: {}", e);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            Err(_) => {
+                let error_msg = format!(
+                    "Responses SSE stream timeout after {}s",
+                    idle_timeout.as_secs()
+                );
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        let raw = sse.data;
+        trace!("Responses SSE: {:?}", raw);
+        if let Some(ref tx) = tx_raw_sse {
+            let _ = tx.send(raw.clone());
+        }
+        if raw == "[DONE]" {
+            return;
+        }
+
+        let event_json: Value = match serde_json::from_str(&raw) {
+            Ok(json) => json,
+            Err(e) => {
+                let error_msg = format!("Responses SSE parsing error: {}, data: {}", e, &raw);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        if let Some(api_error_message) = extract_api_error_message(&event_json) {
+            let error_msg = format!("Responses SSE API error: {}, data: {}", api_error_message, raw);
+            error!("{}", error_msg);
+            let _ = tx_event.send(Err(anyhow!(error_msg)));
+            return;
+        }
+
+        let event: ResponsesStreamEvent = match serde_json::from_value(event_json) {
+            Ok(event) => event,
+            Err(e) => {
+                let error_msg = format!("Responses SSE schema error: {}, data: {}", e, &raw);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        match event.kind.as_str() {
+            "response.output_text.delta" => {
+                if let Some(delta) = event.delta.filter(|delta| !delta.is_empty()) {
+                    received_text_delta = true;
+                    let _ = tx_event.send(Ok(UnifiedResponse {
+                        text: Some(delta),
+                        ..Default::default()
+                    }));
+                }
+            }
+            "response.reasoning_text.delta" | "response.reasoning_summary_text.delta" => {
+                if let Some(delta) = event.delta.filter(|delta| !delta.is_empty()) {
+                    let _ = tx_event.send(Ok(UnifiedResponse {
+                        reasoning_content: Some(delta),
+                        ..Default::default()
+                    }));
+                }
+            }
+            "response.output_item.done" => {
+                if let Some(item_value) = event.item {
+                    if let Some(mut unified_response) = parse_responses_output_item(item_value) {
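+                        // Text deltas were already forwarded as they arrived; drop the
+                        // aggregated text from the completed item so it is not emitted twice.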
+                        if received_text_delta && unified_response.text.is_some() {
+                            unified_response.text = None;
+                        }
+                        if unified_response.text.is_some() || unified_response.tool_call.is_some() {
+                            let _ = tx_event.send(Ok(unified_response));
+                        }
+                    }
+                }
+            }
+            "response.completed" => {
+                match event.response.map(serde_json::from_value::<ResponsesCompleted>) {
+                    Some(Ok(response)) => {
+                        let _ = tx_event.send(Ok(UnifiedResponse {
+                            usage: response.usage.map(Into::into),
+                            finish_reason: Some("stop".to_string()),
+                            ..Default::default()
+                        }));
+                        return;
+                    }
+                    Some(Err(e)) => {
+                        let error_msg = format!("Failed to parse response.completed payload: {}", e);
+                        error!("{}", error_msg);
+                        let _ = tx_event.send(Err(anyhow!(error_msg)));
+                        return;
+                    }
+                    None => {
+                        let _ = tx_event.send(Ok(UnifiedResponse {
+                            finish_reason: Some("stop".to_string()),
+                            ..Default::default()
+                        }));
+                        return;
+                    }
+                }
+            }
+            "response.done" => {
+                match event.response.map(serde_json::from_value::<ResponsesDone>) {
+                    Some(Ok(response)) => {
+                        let _ = tx_event.send(Ok(UnifiedResponse {
+                            usage: response.usage.map(Into::into),
+                            finish_reason: Some("stop".to_string()),
+                            ..Default::default()
+                        }));
+                        return;
+                    }
+                    Some(Err(e)) => {
+                        let error_msg = format!("Failed to parse response.done payload: {}", e);
+                        error!("{}", error_msg);
+                        let _ = tx_event.send(Err(anyhow!(error_msg)));
+                        return;
+                    }
+                    None => {
+                        let _ = tx_event.send(Ok(UnifiedResponse {
+                            finish_reason: Some("stop".to_string()),
+                            ..Default::default()
+                        }));
+                        return;
+                    }
+                }
+            }
+            "response.failed" => {
+                let error_msg = event
+                    .response
+                    .as_ref()
+                    .and_then(|response| response.get("error"))
+                    .and_then(|error| error.get("message"))
+                    .and_then(Value::as_str)
+                    .unwrap_or("Responses API returned response.failed")
+                    .to_string();
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            "response.incomplete" => {
+                let error_msg = event
+                    .response
+                    .as_ref()
+                    .and_then(|response| response.get("incomplete_details"))
+                    .and_then(|details| details.get("reason"))
+                    .and_then(Value::as_str)
+                    .map(|reason| format!("Incomplete response returned, reason: {}", reason))
+                    .unwrap_or_else(|| "Incomplete response returned".to_string());
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            _ => {}
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::extract_api_error_message;
+    use serde_json::json;
+
+    #[test]
+    fn extracts_api_error_message_from_response_error() {
+        let event = json!({
+            "type": "response.failed",
+            "response": {
+                "error": {
+                    "message": "provider error"
+                }
+            }
+        });
+
+        assert_eq!(
+            extract_api_error_message(&event).as_deref(),
+            Some("provider error")
+        );
+    }
+
+    #[test]
+    fn returns_none_when_no_response_error_exists() {
+        let event = json!({
+            "type": "response.created",
+            "response": {
+                "id": "resp_1"
+            }
+        });
+
+        assert!(extract_api_error_message(&event).is_none());
+    }
+
+    #[test]
+    fn returns_none_when_response_error_is_null() {
+        let event = json!({
+            "type": "response.created",
+            "response": {
+                "id": "resp_1",
+                "error": null
+            }
+        });
+
+        assert!(extract_api_error_message(&event).is_none());
+    }
+}
diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs
index 0463a261..496eb3a4 100644
--- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs
+++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs
@@ -1,3 +1,4 @@
 pub mod unified;
 pub mod openai;
-pub mod anthropic;
\ No newline at end of file
+pub mod anthropic;
+pub mod responses;
diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/responses.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/responses.rs
new file mode 100644
index 00000000..c5816806
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/responses.rs
@@ -0,0 +1,172 @@
+use super::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall};
+use serde::Deserialize;
+use serde_json::Value;
+
+#[derive(Debug, Deserialize)]
+pub struct ResponsesStreamEvent {
+    #[serde(rename = "type")]
+    pub kind: String,
+    #[serde(default)]
+    pub response: Option<Value>,
+    #[serde(default)]
+    pub item: Option<Value>,
+    #[serde(default)]
+    pub delta: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ResponsesCompleted {
+    #[allow(dead_code)]
+    pub id: String,
+    #[serde(default)]
+    pub usage: Option<ResponsesUsage>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ResponsesDone {
+    #[serde(default)]
+    #[allow(dead_code)]
+    pub id: Option<String>,
+    #[serde(default)]
+    pub usage: Option<ResponsesUsage>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ResponsesUsage {
+    pub input_tokens: u32,
+    #[serde(default)]
+    pub input_tokens_details: Option<ResponsesInputTokensDetails>,
+    pub output_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ResponsesInputTokensDetails {
+    pub cached_tokens: u32,
+}
+
+impl From<ResponsesUsage> for UnifiedTokenUsage {
+    fn from(usage: ResponsesUsage) -> Self {
+        Self {
+            prompt_token_count: usage.input_tokens,
+            candidates_token_count: usage.output_tokens,
+            total_token_count: usage.total_tokens,
+            cached_content_token_count: usage
+                .input_tokens_details
+                .map(|details| details.cached_tokens),
+        }
+    }
+}
+
+pub fn parse_responses_output_item(item_value: Value) -> Option<UnifiedResponse> {
+    let item_type = item_value.get("type")?.as_str()?;
+
+    match item_type {
+        "function_call" => Some(UnifiedResponse {
+            text: None,
+            reasoning_content: None,
+            thinking_signature: None,
+            tool_call: Some(UnifiedToolCall {
+                id: item_value
+                    .get("call_id")
+                    .and_then(Value::as_str)
+                    .map(ToString::to_string),
+                name: item_value
+                    .get("name")
+                    .and_then(Value::as_str)
+                    .map(ToString::to_string),
+                arguments: item_value
+                    .get("arguments")
+                    .and_then(Value::as_str)
+                    .map(ToString::to_string),
+            }),
+            usage: None,
+            finish_reason: None,
+        }),
+        "message" => {
+            let text = item_value
+                .get("content")
+                .and_then(Value::as_array)
+                .map(|content| {
+                    content
+                        .iter()
+                        .filter(|item| {
+                            item.get("type").and_then(Value::as_str) == Some("output_text")
+                        })
+                        .filter_map(|item| item.get("text").and_then(Value::as_str))
+                        .collect::<String>()
+                })
+                .filter(|text| !text.is_empty());
+
+            text.map(|text| UnifiedResponse {
+                text: Some(text),
+                reasoning_content: None,
+                thinking_signature: None,
+                tool_call: None,
+                usage: None,
+                finish_reason: None,
+            })
+        }
+        _ => None,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{parse_responses_output_item, ResponsesCompleted, ResponsesStreamEvent};
+    use serde_json::json;
+
+    #[test]
+    fn parses_output_text_message_item() {
+        let response = parse_responses_output_item(json!({
+            "type": "message",
+            "role": "assistant",
+            "content": [
+                {
+                    "type": "output_text",
+                    "text": "hello"
+                }
+            ]
+        }))
+        .expect("message item");
+
+        assert_eq!(response.text.as_deref(), Some("hello"));
+    }
+
+    #[test]
+    fn parses_function_call_item() {
+        let response = parse_responses_output_item(json!({
+            "type": "function_call",
+            "call_id": "call_1",
+            "name": "get_weather",
+            "arguments": "{\"city\":\"Beijing\"}"
+        }))
+        .expect("function call item");
+
+        let tool_call = response.tool_call.expect("tool call");
+        assert_eq!(tool_call.id.as_deref(), Some("call_1"));
+        assert_eq!(tool_call.name.as_deref(), Some("get_weather"));
+    }
+
+    #[test]
+    fn parses_completed_payload_usage() {
+        let event: ResponsesStreamEvent = serde_json::from_value(json!({
+            "type": "response.completed",
+            "response": {
+                "id": "resp_1",
+                "usage": {
+                    "input_tokens": 10,
+                    "input_tokens_details": { "cached_tokens": 2 },
+                    "output_tokens": 4,
+                    "total_tokens": 14
+                }
+            }
+        }))
+        .expect("event");
+
+        let completed: ResponsesCompleted = serde_json::from_value(event.response.expect("response"))
+            .expect("completed");
+        assert_eq!(completed.id, "resp_1");
+        assert_eq!(completed.usage.expect("usage").total_tokens, 14);
+    }
+}
diff --git a/src/crates/core/src/infrastructure/ai/client.rs b/src/crates/core/src/infrastructure/ai/client.rs
index c0c25d84..f3ae7508 100644
--- a/src/crates/core/src/infrastructure/ai/client.rs
+++ b/src/crates/core/src/infrastructure/ai/client.rs
@@ -7,7 +7,9 @@ use crate::infrastructure::ai::providers::openai::OpenAIMessageConverter;
 use crate::service::config::ProxyConfig;
 use crate::util::types::*;
 use crate::util::JsonChecker;
-use ai_stream_handlers::{handle_anthropic_stream, handle_openai_stream, UnifiedResponse};
+use ai_stream_handlers::{
+    handle_anthropic_stream, handle_openai_stream, handle_responses_stream, UnifiedResponse,
+};
 use anyhow::{anyhow, Result};
 use futures::StreamExt;
 use log::{debug, error, info, warn};
@@ -95,6 +97,10 @@ impl AIClient {
         color_letter_stream.contains(Self::TEST_IMAGE_EXPECTED_CODE)
     }
 
+    fn is_responses_api_format(api_format: &str) -> bool {
+        matches!(api_format.to_ascii_lowercase().as_str(), "response" | "responses")
+    }
+
     /// Create an AIClient without proxy (backward compatible)
     pub fn new(config: AIConfig) -> Self {
        let skip_ssl_verify = config.skip_ssl_verify;
@@ -442,6 +448,63 @@ impl AIClient {
         request_body
     }
 
+    /// Build a Responses API request body.
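+    ///
+    /// A sketch of the payload this produces (values illustrative; tools and
+    /// `extra_body` overrides are applied on top when present):
+    ///
+    /// ```json
+    /// {
+    ///   "model": "<configured model>",
+    ///   "input": [{ "type": "message", "role": "user", "content": [...] }],
+    ///   "stream": true,
+    ///   "instructions": "<joined system messages>",
+    ///   "max_output_tokens": 1024
+    /// }
+    /// ```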
+    fn build_responses_request_body(
+        &self,
+        instructions: Option<String>,
+        response_input: Vec<serde_json::Value>,
+        openai_tools: Option<Vec<serde_json::Value>>,
+        extra_body: Option<serde_json::Value>,
+    ) -> serde_json::Value {
+        let mut request_body = serde_json::json!({
+            "model": self.config.model,
+            "input": response_input,
+            "stream": true
+        });
+
+        if let Some(instructions) = instructions.filter(|value| !value.trim().is_empty()) {
+            request_body["instructions"] = serde_json::Value::String(instructions);
+        }
+
+        if let Some(max_tokens) = self.config.max_tokens {
+            request_body["max_output_tokens"] = serde_json::json!(max_tokens);
+        }
+
+        if let Some(extra) = extra_body {
+            if let Some(extra_obj) = extra.as_object() {
+                for (key, value) in extra_obj {
+                    request_body[key] = value.clone();
+                }
+                debug!(
+                    target: "ai::responses_stream_request",
+                    "Applied extra_body overrides: {:?}",
+                    extra_obj.keys().collect::<Vec<_>>()
+                );
+            }
+        }
+
+        debug!(
+            target: "ai::responses_stream_request",
+            "Responses stream request body (excluding tools):\n{}",
+            serde_json::to_string_pretty(&request_body)
+                .unwrap_or_else(|_| "serialization failed".to_string())
+        );
+
+        if let Some(tools) = openai_tools {
+            let tool_names = tools
+                .iter()
+                .map(|tool| Self::extract_openai_tool_name(tool))
+                .collect::<Vec<_>>();
+            debug!(target: "ai::responses_stream_request", "\ntools: {:?}", tool_names);
+            if !tools.is_empty() {
+                request_body["tools"] = serde_json::Value::Array(tools);
+                request_body["tool_choice"] = serde_json::Value::String("auto".to_string());
+            }
+        }
+
+        request_body
+    }
+
     /// Build an Anthropic-format request body
     fn build_anthropic_request_body(
         &self,
@@ -555,6 +618,10 @@ impl AIClient {
                 self.send_openai_stream(messages, tools, extra_body, max_tries)
                     .await
             }
+            format if Self::is_responses_api_format(format) => {
+                self.send_responses_stream(messages, tools, extra_body, max_tries)
+                    .await
+            }
             "anthropic" => {
                 self.send_anthropic_stream(messages, tools, extra_body, max_tries)
                     .await
@@ -696,6 +763,120 @@ impl AIClient {
         Err(anyhow!(error_msg))
     }
 
+    /// Send a Responses API streaming request with retries.
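+    ///
+    /// Retries use exponential backoff from a 500ms base (delay capped at 8x);
+    /// 4xx responses fail immediately instead of retrying, and successful
+    /// connections hand the SSE body to `handle_responses_stream`.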
+    async fn send_responses_stream(
+        &self,
+        messages: Vec<Message>,
+        tools: Option<Vec<ToolDefinition>>,
+        extra_body: Option<serde_json::Value>,
+        max_tries: usize,
+    ) -> Result<StreamResponse> {
+        let url = self.config.request_url.clone();
+        debug!(
+            "Responses config: model={}, request_url={}, max_tries={}",
+            self.config.model, self.config.request_url, max_tries
+        );
+
+        let (instructions, response_input) =
+            OpenAIMessageConverter::convert_messages_to_responses_input(messages);
+        let openai_tools = OpenAIMessageConverter::convert_tools(tools);
+        let request_body =
+            self.build_responses_request_body(instructions, response_input, openai_tools, extra_body);
+
+        let mut last_error = None;
+        let base_wait_time_ms = 500;
+
+        for attempt in 0..max_tries {
+            let request_start_time = std::time::Instant::now();
+            let request_builder = self.apply_openai_headers(self.client.post(&url));
+            let response_result = request_builder.json(&request_body).send().await;
+
+            let response = match response_result {
+                Ok(resp) => {
+                    let connect_time = request_start_time.elapsed().as_millis();
+                    let status = resp.status();
+
+                    if status.is_client_error() {
+                        let error_text = resp
+                            .text()
+                            .await
+                            .unwrap_or_else(|e| format!("Failed to read error response: {}", e));
+                        error!("Responses API client error {}: {}", status, error_text);
+                        return Err(anyhow!("Responses API client error {}: {}", status, error_text));
+                    }
+
+                    if status.is_success() {
+                        debug!(
+                            "Responses request connected: {}ms, status: {}, attempt: {}/{}",
+                            connect_time,
+                            status,
+                            attempt + 1,
+                            max_tries
+                        );
+                        resp
+                    } else {
+                        let error_text = resp
+                            .text()
+                            .await
+                            .unwrap_or_else(|e| format!("Failed to read error response: {}", e));
+                        let error = anyhow!("Responses API error {}: {}", status, error_text);
+                        warn!(
+                            "Responses request failed (attempt {}/{}): {}",
+                            attempt + 1,
+                            max_tries,
+                            error
+                        );
+                        last_error = Some(error);
+
+                        if attempt < max_tries - 1 {
+                            let delay_ms = base_wait_time_ms * (1 << attempt.min(3));
+                            debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2);
+                            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
+                        }
+                        continue;
+                    }
+                }
+                Err(e) => {
+                    let connect_time = request_start_time.elapsed().as_millis();
+                    let error = anyhow!("Responses request connection failed: {}", e);
+                    warn!(
+                        "Responses request connection failed: {}ms, attempt {}/{}, error: {}",
+                        connect_time,
+                        attempt + 1,
+                        max_tries,
+                        e
+                    );
+                    last_error = Some(error);
+
+                    if attempt < max_tries - 1 {
+                        let delay_ms = base_wait_time_ms * (1 << attempt.min(3));
+                        debug!("Retrying after {}ms (attempt {})", delay_ms, attempt + 2);
+                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
+                    }
+                    continue;
+                }
+            };
+
+            let (tx, rx) = mpsc::unbounded_channel();
+            let (tx_raw, rx_raw) = mpsc::unbounded_channel();
+
+            tokio::spawn(handle_responses_stream(response, tx, Some(tx_raw)));
+
+            return Ok(StreamResponse {
+                stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)),
+                raw_sse_rx: Some(rx_raw),
+            });
+        }
+
+        let error_msg = format!(
+            "Responses request failed after {} attempts: {}",
+            max_tries,
+            last_error.unwrap_or_else(|| anyhow!("Unknown error"))
+        );
+        error!("{}", error_msg);
+        Err(anyhow!(error_msg))
+    }
+
     /// Send an Anthropic streaming request with retries
     ///
     /// # Parameters
diff --git a/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs b/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs
index 7c04e443..0eb1de14 100644
--- a/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs
+++ b/src/crates/core/src/infrastructure/ai/providers/openai/message_converter.rs
@@ -7,12 +7,156 @@ use serde_json::{json, Value};
 pub struct OpenAIMessageConverter;
 
 impl OpenAIMessageConverter {
+    pub fn convert_messages_to_responses_input(messages: Vec<Message>) -> (Option<String>, Vec<Value>) {
+        let mut instructions = Vec::new();
+        let mut input = Vec::new();
+
+        for msg in messages {
+            match msg.role.as_str() {
+                "system" => {
+                    if let Some(content) = msg.content.filter(|content| !content.trim().is_empty()) {
+                        instructions.push(content);
+                    }
+                }
+                "tool" => {
+                    if let Some(tool_item) = Self::convert_tool_message_to_responses_item(msg) {
+                        input.push(tool_item);
+                    }
+                }
+                "assistant" => {
+                    if let Some(content_items) = Self::convert_message_content_to_responses_items(&msg.role, msg.content.as_deref()) {
+                        input.push(json!({
+                            "type": "message",
+                            "role": "assistant",
+                            "content": content_items,
+                        }));
+                    }
+
+                    if let Some(tool_calls) = msg.tool_calls {
+                        for tool_call in tool_calls {
+                            input.push(json!({
+                                "type": "function_call",
+                                "call_id": tool_call.id,
+                                "name": tool_call.name,
+                                "arguments": serde_json::to_string(&tool_call.arguments)
+                                    .unwrap_or_else(|_| "{}".to_string()),
+                            }));
+                        }
+                    }
+                }
+                role => {
+                    if let Some(content_items) = Self::convert_message_content_to_responses_items(role, msg.content.as_deref()) {
+                        input.push(json!({
+                            "type": "message",
+                            "role": role,
+                            "content": content_items,
+                        }));
+                    }
+                }
+            }
+        }
+
+        let instructions = if instructions.is_empty() {
+            None
+        } else {
+            Some(instructions.join("\n\n"))
+        };
+
+        (instructions, input)
+    }
+
     pub fn convert_messages(messages: Vec<Message>) -> Vec<Value> {
         messages.into_iter()
             .map(Self::convert_single_message)
             .collect()
     }
 
+    fn convert_tool_message_to_responses_item(msg: Message) -> Option<Value> {
+        let call_id = msg.tool_call_id?;
+        let output = msg.content.unwrap_or_else(|| "Tool execution completed".to_string());
+
+        Some(json!({
+            "type": "function_call_output",
+            "call_id": call_id,
+            "output": output,
+        }))
+    }
+
+    fn convert_message_content_to_responses_items(role: &str, content: Option<&str>) -> Option<Vec<Value>> {
+        let content = content?;
+        let text_item_type = Self::responses_text_item_type(role);
+
+        if content.trim().is_empty() {
+            return Some(vec![json!({
+                "type": text_item_type,
+                "text": " ",
+            })]);
+        }
+
+        let parsed = match serde_json::from_str::<Value>(content) {
+            Ok(parsed) if parsed.is_array() => parsed,
+            _ => {
+                return Some(vec![json!({
+                    "type": text_item_type,
+                    "text": content,
+                })]);
+            }
+        };
+
+        let mut content_items = Vec::new();
+
+        if let Some(items) = parsed.as_array() {
+            for item in items {
+                let item_type = item.get("type").and_then(Value::as_str);
+                match item_type {
+                    Some("text") | Some("input_text") | Some("output_text") => {
+                        if let Some(text) = item.get("text").and_then(Value::as_str) {
+                            content_items.push(json!({
+                                "type": text_item_type,
+                                "text": text,
+                            }));
+                        }
+                    }
+                    Some("image_url") if role != "assistant" => {
+                        let image_url = item
+                            .get("image_url")
+                            .and_then(|value| {
+                                value
+                                    .get("url")
+                                    .and_then(Value::as_str)
+                                    .or_else(|| value.as_str())
+                            });
+
+                        if let Some(image_url) = image_url {
+                            content_items.push(json!({
+                                "type": "input_image",
+                                "image_url": image_url,
+                            }));
+                        }
+                    }
+                    _ => {}
+                }
+            }
+        }
+
+        if content_items.is_empty() {
+            Some(vec![json!({
+                "type": text_item_type,
+                "text": content,
+            })])
+        } else {
+            Some(content_items)
+        }
+    }
+
+    fn responses_text_item_type(role: &str) -> &'static str {
+        if role == "assistant" {
+            "output_text"
+        } else {
+            "input_text"
+        }
+    }
+
     fn convert_single_message(msg: Message) -> Value {
         let mut openai_msg = json!({
             "role": msg.role,
@@ -125,3 +269,73 @@ impl OpenAIMessageConverter {
     }
 }
 
+#[cfg(test)]
+mod tests {
+    use super::OpenAIMessageConverter;
+    use crate::util::types::{Message, ToolCall};
+    use serde_json::json;
+    use std::collections::HashMap;
+
+    #[test]
+    fn converts_messages_to_responses_input() {
+        let mut args = HashMap::new();
+        args.insert("city".to_string(), json!("Beijing"));
+
+        let messages = vec![
+            Message::system("You are helpful".to_string()),
+            Message::user("Hello".to_string()),
+            Message::assistant_with_tools(vec![ToolCall {
+                id: "call_1".to_string(),
+                name: "get_weather".to_string(),
+                arguments: args.clone(),
+            }]),
+            Message {
+                role: "tool".to_string(),
+                content: Some("Sunny".to_string()),
+                reasoning_content: None,
+                thinking_signature: None,
+                tool_calls: None,
+                tool_call_id: Some("call_1".to_string()),
+                name: Some("get_weather".to_string()),
+            },
+        ];
+
+        let (instructions, input) = OpenAIMessageConverter::convert_messages_to_responses_input(messages);
+
+        assert_eq!(instructions.as_deref(), Some("You are helpful"));
+        assert_eq!(input.len(), 3);
+        assert_eq!(input[0]["type"], json!("message"));
+        assert_eq!(input[1]["type"], json!("function_call"));
+        assert_eq!(input[2]["type"], json!("function_call_output"));
+    }
+
+    #[test]
+    fn converts_openai_style_image_content_to_responses_input() {
+        let messages = vec![Message {
+            role: "user".to_string(),
+            content: Some(json!([
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "data:image/png;base64,abc"
+                    }
+                },
+                {
+                    "type": "text",
+                    "text": "Describe this image"
+                }
+            ]).to_string()),
+            reasoning_content: None,
+            thinking_signature: None,
+            tool_calls: None,
+            tool_call_id: None,
+            name: None,
+        }];
+
+        let (_, input) = OpenAIMessageConverter::convert_messages_to_responses_input(messages);
+        let content = input[0]["content"].as_array().expect("content array");
+
+        assert_eq!(content[0]["type"], json!("input_image"));
+        assert_eq!(content[1]["type"], json!("input_text"));
+    }
+}
diff --git a/src/crates/core/src/util/types/config.rs b/src/crates/core/src/util/types/config.rs
index bc0212fd..afd0649d 100644
--- a/src/crates/core/src/util/types/config.rs
+++ b/src/crates/core/src/util/types/config.rs
@@ -25,6 +25,7 @@ fn resolve_request_url(base_url: &str, provider: &str) -> String {
 
     match provider.trim().to_ascii_lowercase().as_str() {
         "openai" => append_endpoint(&trimmed, "chat/completions"),
+        "response" | "responses" => append_endpoint(&trimmed, "responses"),
         "anthropic" => append_endpoint(&trimmed, "v1/messages"),
         _ => trimmed,
     }
@@ -53,6 +54,43 @@ pub struct AIConfig {
     pub custom_request_body: Option<String>,
 }
 
+#[cfg(test)]
+mod tests {
+    use super::resolve_request_url;
+
+    #[test]
+    fn resolves_openai_request_url() {
+        assert_eq!(
+            resolve_request_url("https://api.openai.com/v1", "openai"),
+            "https://api.openai.com/v1/chat/completions"
+        );
+    }
+
+    #[test]
+    fn resolves_responses_request_url() {
+        assert_eq!(
+            resolve_request_url("https://api.openai.com/v1", "responses"),
+            "https://api.openai.com/v1/responses"
+        );
+    }
+
+    #[test]
+    fn resolves_response_alias_request_url() {
+        assert_eq!(
+            resolve_request_url("https://api.openai.com/v1", "response"),
+            "https://api.openai.com/v1/responses"
+        );
+    }
+
+    #[test]
+    fn keeps_forced_request_url() {
+        assert_eq!(
+            resolve_request_url("https://api.openai.com/v1/responses#", "responses"),
+            "https://api.openai.com/v1/responses"
+        );
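+        // A trailing '#' forces the base URL to be used as-is (the '#' itself is stripped).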
} +} + impl TryFrom for AIConfig { type Error = String; fn try_from(other: AIModelConfig) -> Result>::Error> { diff --git a/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx b/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx index 2e95a4d7..eee62346 100644 --- a/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx +++ b/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx @@ -33,8 +33,8 @@ export const ModelConfigStep: React.FC = ({ onSkipForNow } const [apiKey, setApiKey] = useState(modelConfig?.apiKey || ''); const [baseUrl, setBaseUrl] = useState(modelConfig?.baseUrl || ''); const [modelName, setModelName] = useState(modelConfig?.modelName || ''); - const [customFormat, setCustomFormat] = useState<'openai' | 'anthropic'>( - (modelConfig?.format as 'openai' | 'anthropic') || 'openai' + const [customFormat, setCustomFormat] = useState<'openai' | 'responses' | 'anthropic'>( + (modelConfig?.format as 'openai' | 'responses' | 'anthropic') || 'openai' ); const [testStatus, setTestStatus] = useState('idle'); const [testError, setTestError] = useState(''); @@ -120,7 +120,7 @@ export const ModelConfigStep: React.FC = ({ onSkipForNow } const effectiveModelName = modelName || (template?.models[0] || ''); // Derive format - let format: 'openai' | 'anthropic' = customFormat; + let format: 'openai' | 'responses' | 'anthropic' = customFormat; if (template) { if (template.baseUrlOptions?.length) { const effectiveUrl = baseUrl || template.baseUrl; @@ -501,10 +501,11 @@ export const ModelConfigStep: React.FC = ({ onSkipForNow } label={t('model.format.label')} options={[ { label: 'OpenAI', value: 'openai' }, + { label: 'OpenAI Responses', value: 'responses' }, { label: 'Anthropic', value: 'anthropic' } ]} value={customFormat} - onChange={(val) => setCustomFormat(val as 'openai' | 'anthropic')} + onChange={(val) => setCustomFormat(val as 'openai' | 'responses' | 'anthropic')} placeholder={t('model.format.placeholder')} /> diff --git a/src/web-ui/src/features/onboarding/store/onboardingStore.ts b/src/web-ui/src/features/onboarding/store/onboardingStore.ts index be1f6796..fcaf6988 100644 --- a/src/web-ui/src/features/onboarding/store/onboardingStore.ts +++ b/src/web-ui/src/features/onboarding/store/onboardingStore.ts @@ -37,7 +37,7 @@ export interface OnboardingModelConfig { modelName?: string; testPassed?: boolean; // Fields needed for saving the model config on completion - format?: 'openai' | 'anthropic'; + format?: 'openai' | 'responses' | 'anthropic'; configName?: string; customRequestBody?: string; skipSslVerify?: boolean; diff --git a/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx b/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx index ed067090..b896adb8 100644 --- a/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx +++ b/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx @@ -25,6 +25,7 @@ const log = createLogger('AIModelConfig'); * Rules: * - Ends with '#' → strip '#', use as-is (force override) * - openai → append '/chat/completions' unless already present + * - responses → append '/responses' unless already present * - anthropic → append '/v1/messages' unless already present * - other → use base_url as-is */ @@ -36,6 +37,9 @@ function resolveRequestUrl(baseUrl: string, provider: string): string { if (provider === 'openai') { return trimmed.endsWith('chat/completions') ? 
     return trimmed.endsWith('chat/completions') ? trimmed : `${trimmed}/chat/completions`;
   }
+  if (provider === 'response' || provider === 'responses') {
+    return trimmed.endsWith('responses') ? trimmed : `${trimmed}/responses`;
+  }
   if (provider === 'anthropic') {
     return trimmed.endsWith('v1/messages') ? trimmed : `${trimmed}/v1/messages`;
   }
@@ -69,6 +73,15 @@
   });
   const [isProxySaving, setIsProxySaving] = useState(false);
 
+  const requestFormatOptions = useMemo(
+    () => [
+      { label: 'OpenAI (chat/completions)', value: 'openai' },
+      { label: 'OpenAI (responses)', value: 'responses' },
+      { label: 'Anthropic (messages)', value: 'anthropic' },
+    ],
+    []
+  );
+
   useEffect(() => {
     loadConfig();
@@ -662,7 +675,7 @@
               )}
-
+              <Input value={editingConfig.model_name} onChange={(e) => setEditingConfig(prev => ({ ...prev, model_name: e.target.value }))} placeholder={editingConfig.category === 'speech_recognition' ? 'glm-asr' : 'glm-4.7'} inputSize="small" />
-
-
+              <Select value={editingConfig.provider} onChange={(value) => setEditingConfig(prev => ({ ...prev, provider: value as string }))} placeholder={t('form.providerPlaceholder')} options={requestFormatOptions} />
               {editingConfig.category !== 'speech_recognition' && (
                 <>
@@ -1104,4 +1117,3 @@ const AIModelConfig: React.FC = () => {
 };
 
 export default AIModelConfig;
-
diff --git a/src/web-ui/src/infrastructure/config/schemas/ai-models.json b/src/web-ui/src/infrastructure/config/schemas/ai-models.json
index f282415c..8a4ae31d 100644
--- a/src/web-ui/src/infrastructure/config/schemas/ai-models.json
+++ b/src/web-ui/src/infrastructure/config/schemas/ai-models.json
@@ -115,6 +115,11 @@
         "label": "OpenAI",
         "description": "GPT-3.5, GPT-4 等模型"
       },
+      {
+        "value": "responses",
+        "label": "OpenAI Responses",
+        "description": "OpenAI Responses API /responses 兼容模型"
+      },
       {
         "value": "anthropic",
         "label": "Anthropic",
diff --git a/src/web-ui/src/infrastructure/config/services/modelConfigs.ts b/src/web-ui/src/infrastructure/config/services/modelConfigs.ts
index a60b6a36..090be754 100644
--- a/src/web-ui/src/infrastructure/config/services/modelConfigs.ts
+++ b/src/web-ui/src/infrastructure/config/services/modelConfigs.ts
@@ -306,6 +306,8 @@ export const getFormatDisplayName = (format: ApiFormat): string => {
   switch (format) {
     case 'openai':
       return t('settings/ai-model:formats.openaiCompatible');
+    case 'responses':
+      return t('settings/ai-model:formats.responsesApi');
     case 'anthropic':
       return t('settings/ai-model:formats.claudeApi');
     default:
diff --git a/src/web-ui/src/locales/en-US/settings/ai-model.json b/src/web-ui/src/locales/en-US/settings/ai-model.json
index ae9cbcf8..2c49ab73 100644
--- a/src/web-ui/src/locales/en-US/settings/ai-model.json
+++ b/src/web-ui/src/locales/en-US/settings/ai-model.json
@@ -149,6 +149,7 @@
   },
   "formats": {
     "openaiCompatible": "OpenAI Compatible",
+    "responsesApi": "OpenAI Responses API",
     "claudeApi": "Claude API"
   },
   "actions": {
diff --git a/src/web-ui/src/locales/zh-CN/settings/ai-model.json b/src/web-ui/src/locales/zh-CN/settings/ai-model.json
index c915b15c..9ff37610 100644
--- a/src/web-ui/src/locales/zh-CN/settings/ai-model.json
+++ b/src/web-ui/src/locales/zh-CN/settings/ai-model.json
@@ -100,7 +100,7 @@
     "searchApiHint": "智谱AI搜索API地址,或其他兼容的搜索API地址",
     "apiKey": "API密钥",
     "apiKeyPlaceholder": "输入您的 API Key",
-    "provider": "请求规范",
+    "provider": "请求格式",
     "providerPlaceholder": "选择请求格式",
     "contextWindow": "上下文窗口大小",
     "contextWindowHint": "模型支持的最大上下文窗口大小(用于计算token占用百分比)",
@@ -118,7 +118,7 @@
   "details": {
     "basicInfo": "基本信息",
     "modelName": "模型名称",
-    "provider": "请求规范",
+    "provider": "请求格式",
     "apiUrl": "API地址",
     "contextWindow": "上下文窗口",
"maxOutput": "最大输出", @@ -149,6 +149,7 @@ }, "formats": { "openaiCompatible": "OpenAI 兼容", + "responsesApi": "OpenAI Responses API", "claudeApi": "Claude API" }, "actions": { diff --git a/src/web-ui/src/shared/types/chat.ts b/src/web-ui/src/shared/types/chat.ts index b312cb21..dc1b3690 100644 --- a/src/web-ui/src/shared/types/chat.ts +++ b/src/web-ui/src/shared/types/chat.ts @@ -11,7 +11,7 @@ export type MessageStatus = 'pending' | 'sending' | 'sent' | 'error'; export type ConversationStatus = 'pending' | 'completed' | 'failed' | 'cancelled'; -export type ApiFormat = 'openai' | 'anthropic'; +export type ApiFormat = 'openai' | 'responses' | 'anthropic'; export interface ToolExecution { From 845d711be1852ca0debee0e7a4ab2eee709144aa Mon Sep 17 00:00:00 2001 From: wgqqqqq Date: Mon, 9 Mar 2026 15:32:50 +0800 Subject: [PATCH 2/4] feat(ai): add Gemini provider support --- .../ai/ai_stream_handlers/src/lib.rs | 1 + .../src/stream_handler/gemini.rs | 248 +++++++ .../src/stream_handler/mod.rs | 2 + .../ai/ai_stream_handlers/src/types/gemini.rs | 247 +++++++ .../ai/ai_stream_handlers/src/types/mod.rs | 1 + .../core/src/infrastructure/ai/client.rs | 279 +++++++- .../ai/providers/gemini/message_converter.rs | 669 ++++++++++++++++++ .../infrastructure/ai/providers/gemini/mod.rs | 5 + .../src/infrastructure/ai/providers/mod.rs | 3 +- src/crates/core/src/util/types/config.rs | 60 +- .../components/steps/ModelConfigStep.tsx | 13 +- .../onboarding/store/onboardingStore.ts | 2 +- .../config/components/AIModelConfig.tsx | 46 +- .../config/services/modelConfigs.ts | 13 + .../src/locales/en-US/settings/ai-model.json | 7 +- .../src/locales/zh-CN/settings/ai-model.json | 7 +- src/web-ui/src/shared/types/chat.ts | 2 +- 17 files changed, 1572 insertions(+), 33 deletions(-) create mode 100644 src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/gemini.rs create mode 100644 src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/gemini.rs create mode 100644 src/crates/core/src/infrastructure/ai/providers/gemini/message_converter.rs create mode 100644 src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs index 2aed04b0..d10cee11 100644 --- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/lib.rs @@ -2,6 +2,7 @@ mod stream_handler; mod types; pub use stream_handler::handle_anthropic_stream; +pub use stream_handler::handle_gemini_stream; pub use stream_handler::handle_openai_stream; pub use stream_handler::handle_responses_stream; pub use types::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall}; diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/gemini.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/gemini.rs new file mode 100644 index 00000000..395ea7d8 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/gemini.rs @@ -0,0 +1,248 @@ +use crate::types::gemini::GeminiSSEData; +use crate::types::unified::UnifiedResponse; +use anyhow::{anyhow, Result}; +use eventsource_stream::Eventsource; +use futures::StreamExt; +use log::{error, trace}; +use reqwest::Response; +use serde_json::Value; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::timeout; + +static 
+static GEMINI_STREAM_ID_SEQ: AtomicU64 = AtomicU64::new(1);
+
+#[derive(Debug)]
+struct GeminiToolCallState {
+    active_name: Option<String>,
+    active_id: Option<String>,
+    stream_id: u64,
+    next_index: usize,
+}
+
+impl GeminiToolCallState {
+    fn new() -> Self {
+        Self {
+            active_name: None,
+            active_id: None,
+            stream_id: GEMINI_STREAM_ID_SEQ.fetch_add(1, Ordering::Relaxed),
+            next_index: 0,
+        }
+    }
+
+    fn on_non_tool_response(&mut self) {
+        self.active_name = None;
+        self.active_id = None;
+    }
+
+    fn assign_id(&mut self, tool_call: &mut crate::types::unified::UnifiedToolCall) {
+        if let Some(existing_id) = tool_call.id.as_ref().filter(|value| !value.is_empty()) {
+            self.active_id = Some(existing_id.clone());
+            self.active_name = tool_call.name.clone().filter(|value| !value.is_empty());
+            return;
+        }
+
+        let tool_name = tool_call.name.clone().filter(|value| !value.is_empty());
+        let is_same_active_call = self.active_id.is_some() && self.active_name == tool_name;
+
+        if is_same_active_call {
+            tool_call.id = None;
+            return;
+        }
+
+        self.next_index += 1;
+        let generated_id = format!("gemini_call_{}_{}", self.stream_id, self.next_index);
+        tool_call.id = Some(generated_id.clone());
+        self.active_id = Some(generated_id);
+        self.active_name = tool_name;
+    }
+}
+
+fn extract_api_error_message(event_json: &Value) -> Option<String> {
+    let error = event_json.get("error")?;
+    if let Some(message) = error.get("message").and_then(Value::as_str) {
+        return Some(message.to_string());
+    }
+    if let Some(message) = error.as_str() {
+        return Some(message.to_string());
+    }
+    Some("Gemini streaming request failed".to_string())
+}
+
+pub async fn handle_gemini_stream(
+    response: Response,
+    tx_event: mpsc::UnboundedSender<Result<UnifiedResponse>>,
+    tx_raw_sse: Option<mpsc::UnboundedSender<String>>,
+) {
+    let mut stream = response.bytes_stream().eventsource();
+    let idle_timeout = Duration::from_secs(600);
+    let mut received_finish_reason = false;
+    let mut tool_call_state = GeminiToolCallState::new();
+
+    loop {
+        let sse_event = timeout(idle_timeout, stream.next()).await;
+        let sse = match sse_event {
+            Ok(Some(Ok(sse))) => sse,
+            Ok(None) => {
+                if received_finish_reason {
+                    return;
+                }
+                let error_msg = "Gemini SSE stream closed before response completed";
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            Ok(Some(Err(e))) => {
+                let error_msg = format!("Gemini SSE stream error: {}", e);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+            Err(_) => {
+                let error_msg = format!(
+                    "Gemini SSE stream timeout after {}s",
+                    idle_timeout.as_secs()
+                );
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        let raw = sse.data;
+        trace!("Gemini SSE: {:?}", raw);
+
+        if let Some(ref tx) = tx_raw_sse {
+            let _ = tx.send(raw.clone());
+        }
+
+        if raw == "[DONE]" {
+            return;
+        }
+
+        let event_json: Value = match serde_json::from_str(&raw) {
+            Ok(json) => json,
+            Err(e) => {
+                let error_msg = format!("Gemini SSE parsing error: {}, data: {}", e, raw);
+                error!("{}", error_msg);
+                let _ = tx_event.send(Err(anyhow!(error_msg)));
+                return;
+            }
+        };
+
+        if let Some(message) = extract_api_error_message(&event_json) {
+            let error_msg = format!("Gemini SSE API error: {}, data: {}", message, raw);
+            error!("{}", error_msg);
+            let _ = tx_event.send(Err(anyhow!(error_msg)));
+            return;
+        }
+
+        let sse_data: GeminiSSEData = match serde_json::from_value(event_json) {
+            Ok(data) => data,
+            Err(e) => {
+                let error_msg = format!("Gemini SSE data schema error: {}, data: {}", e, raw);
error!("{}", error_msg); + let _ = tx_event.send(Err(anyhow!(error_msg))); + return; + } + }; + + let mut unified_responses = sse_data.into_unified_responses(); + for unified_response in &mut unified_responses { + if let Some(tool_call) = unified_response.tool_call.as_mut() { + tool_call_state.assign_id(tool_call); + } else { + tool_call_state.on_non_tool_response(); + } + + if unified_response.finish_reason.is_some() { + received_finish_reason = true; + tool_call_state.on_non_tool_response(); + } + } + + for unified_response in unified_responses { + let _ = tx_event.send(Ok(unified_response)); + } + } +} + +#[cfg(test)] +mod tests { + use super::GeminiToolCallState; + use crate::types::unified::UnifiedToolCall; + + #[test] + fn reuses_active_tool_id_by_omitting_follow_up_ids() { + let mut state = GeminiToolCallState::new(); + + let mut first = UnifiedToolCall { + id: None, + name: Some("get_weather".to_string()), + arguments: Some("{\"city\":".to_string()), + }; + state.assign_id(&mut first); + + let mut second = UnifiedToolCall { + id: None, + name: Some("get_weather".to_string()), + arguments: Some("\"Paris\"}".to_string()), + }; + state.assign_id(&mut second); + + assert!(first + .id + .as_deref() + .is_some_and(|id| id.starts_with("gemini_call_"))); + assert!(second.id.is_none()); + } + + #[test] + fn clears_active_tool_after_non_tool_response() { + let mut state = GeminiToolCallState::new(); + + let mut first = UnifiedToolCall { + id: None, + name: Some("get_weather".to_string()), + arguments: Some("{}".to_string()), + }; + state.assign_id(&mut first); + state.on_non_tool_response(); + + let mut second = UnifiedToolCall { + id: None, + name: Some("get_weather".to_string()), + arguments: Some("{}".to_string()), + }; + state.assign_id(&mut second); + + let first_id = first.id.expect("first id"); + let second_id = second.id.expect("second id"); + assert!(first_id.starts_with("gemini_call_")); + assert!(second_id.starts_with("gemini_call_")); + assert_ne!(first_id, second_id); + } + + #[test] + fn generates_unique_prefixes_across_streams() { + let mut first_state = GeminiToolCallState::new(); + let mut second_state = GeminiToolCallState::new(); + + let mut first = UnifiedToolCall { + id: None, + name: Some("grep".to_string()), + arguments: Some("{}".to_string()), + }; + let mut second = UnifiedToolCall { + id: None, + name: Some("read".to_string()), + arguments: Some("{}".to_string()), + }; + + first_state.assign_id(&mut first); + second_state.assign_id(&mut second); + + assert_ne!(first.id, second.id); + } +} diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs index b117f3c7..24e2938a 100644 --- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/stream_handler/mod.rs @@ -1,7 +1,9 @@ mod openai; mod anthropic; mod responses; +mod gemini; pub use openai::handle_openai_stream; pub use anthropic::handle_anthropic_stream; pub use responses::handle_responses_stream; +pub use gemini::handle_gemini_stream; diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/gemini.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/gemini.rs new file mode 100644 index 00000000..ed344d28 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/gemini.rs @@ -0,0 +1,247 @@ +use 
+use crate::types::unified::{UnifiedResponse, UnifiedTokenUsage, UnifiedToolCall};
+use serde::Deserialize;
+use serde_json::{json, Value};
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiSSEData {
+    #[serde(default)]
+    pub candidates: Vec<GeminiCandidate>,
+    #[serde(default)]
+    pub usage_metadata: Option<GeminiUsageMetadata>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiCandidate {
+    #[serde(default)]
+    pub content: Option<GeminiContent>,
+    #[serde(default)]
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiContent {
+    #[serde(default)]
+    pub parts: Vec<GeminiPart>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiPart {
+    #[serde(default)]
+    pub text: Option<String>,
+    #[serde(default)]
+    pub thought: Option<bool>,
+    #[serde(default)]
+    pub thought_signature: Option<String>,
+    #[serde(default)]
+    pub function_call: Option<GeminiFunctionCall>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiFunctionCall {
+    #[serde(default)]
+    pub name: Option<String>,
+    #[serde(default)]
+    pub args: Option<Value>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+#[serde(rename_all = "camelCase")]
+pub struct GeminiUsageMetadata {
+    #[serde(default)]
+    pub prompt_token_count: u32,
+    #[serde(default)]
+    pub candidates_token_count: u32,
+    #[serde(default)]
+    pub total_token_count: u32,
+    #[serde(default)]
+    pub cached_content_token_count: Option<u32>,
+}
+
+impl From<GeminiUsageMetadata> for UnifiedTokenUsage {
+    fn from(usage: GeminiUsageMetadata) -> Self {
+        Self {
+            prompt_token_count: usage.prompt_token_count,
+            candidates_token_count: usage.candidates_token_count,
+            total_token_count: usage.total_token_count,
+            cached_content_token_count: usage.cached_content_token_count,
+        }
+    }
+}
+
+impl GeminiSSEData {
+    pub fn into_unified_responses(self) -> Vec<UnifiedResponse> {
+        let mut usage = self.usage_metadata.map(Into::into);
+        let Some(candidate) = self.candidates.into_iter().next() else {
+            return usage
+                .take()
+                .map(|usage| {
+                    vec![UnifiedResponse {
+                        usage: Some(usage),
+                        ..Default::default()
+                    }]
+                })
+                .unwrap_or_default();
+        };
+
+        let mut responses = Vec::new();
+        let mut finish_reason = candidate.finish_reason;
+
+        if let Some(content) = candidate.content {
+            for part in content.parts {
+                let has_function_call = part.function_call.is_some();
+                let text = part.text.filter(|text| !text.is_empty());
+                let is_thought = part.thought.unwrap_or(false);
+                let thinking_signature = part.thought_signature.filter(|value| !value.is_empty());
+
+                if let Some(function_call) = part.function_call {
+                    let arguments = function_call.args.unwrap_or_else(|| json!({}));
+                    responses.push(UnifiedResponse {
+                        text: None,
+                        reasoning_content: None,
+                        thinking_signature,
+                        tool_call: Some(UnifiedToolCall {
+                            id: None,
+                            name: function_call.name,
+                            arguments: serde_json::to_string(&arguments).ok(),
+                        }),
+                        usage: usage.take(),
+                        finish_reason: finish_reason.take(),
+                    });
+                    continue;
+                }
+
+                if let Some(text) = text {
+                    responses.push(UnifiedResponse {
+                        text: if is_thought { None } else { Some(text.clone()) },
+                        reasoning_content: if is_thought { Some(text) } else { None },
+                        thinking_signature,
+                        tool_call: None,
+                        usage: usage.take(),
+                        finish_reason: finish_reason.take(),
+                    });
+                    continue;
+                }
+
+                if thinking_signature.is_some() && !has_function_call {
+                    responses.push(UnifiedResponse {
+                        text: None,
+                        reasoning_content: None,
+                        thinking_signature,
+                        tool_call: None,
+                        usage: usage.take(),
+                        finish_reason: finish_reason.take(),
+                    });
+                }
+            }
+        }
+
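+        // No text or tool parts were produced; still emit one response so any
+        // usage metadata and finish reason are not lost.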
+        if responses.is_empty() {
+            responses.push(UnifiedResponse {
+                usage,
+                finish_reason,
+                ..Default::default()
+            });
+        }
+
+        responses
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::GeminiSSEData;
+
+    #[test]
+    fn converts_text_thought_and_usage() {
+        let payload = serde_json::json!({
+            "candidates": [{
+                "content": {
+                    "parts": [
+                        { "text": "thinking", "thought": true, "thoughtSignature": "sig_1" },
+                        { "text": "answer" }
+                    ]
+                },
+                "finishReason": "STOP"
+            }],
+            "usageMetadata": {
+                "promptTokenCount": 10,
+                "candidatesTokenCount": 4,
+                "totalTokenCount": 14
+            }
+        });
+
+        let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload");
+        let responses = data.into_unified_responses();
+
+        assert_eq!(responses.len(), 2);
+        assert_eq!(responses[0].reasoning_content.as_deref(), Some("thinking"));
+        assert_eq!(responses[0].thinking_signature.as_deref(), Some("sig_1"));
+        assert_eq!(
+            responses[0]
+                .usage
+                .as_ref()
+                .map(|usage| usage.total_token_count),
+            Some(14)
+        );
+        assert_eq!(responses[1].text.as_deref(), Some("answer"));
+    }
+
+    #[test]
+    fn keeps_thought_signature_on_function_call_parts() {
+        let payload = serde_json::json!({
+            "candidates": [{
+                "content": {
+                    "parts": [
+                        {
+                            "thoughtSignature": "sig_tool",
+                            "functionCall": {
+                                "name": "get_weather",
+                                "args": { "city": "Paris" }
+                            }
+                        }
+                    ]
+                }
+            }]
+        });
+
+        let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload");
+        let responses = data.into_unified_responses();
+
+        assert_eq!(responses.len(), 1);
+        assert_eq!(responses[0].thinking_signature.as_deref(), Some("sig_tool"));
+        assert_eq!(
+            responses[0]
+                .tool_call
+                .as_ref()
+                .and_then(|tool_call| tool_call.name.as_deref()),
+            Some("get_weather")
+        );
+    }
+
+    #[test]
+    fn keeps_standalone_thought_signature_parts() {
+        let payload = serde_json::json!({
+            "candidates": [{
+                "content": {
+                    "parts": [
+                        { "thoughtSignature": "sig_only" }
+                    ]
+                }
+            }]
+        });
+
+        let data: GeminiSSEData = serde_json::from_value(payload).expect("gemini payload");
+        let responses = data.into_unified_responses();
+
+        assert_eq!(responses.len(), 1);
+        assert_eq!(responses[0].thinking_signature.as_deref(), Some("sig_only"));
+        assert!(responses[0].tool_call.is_none());
+        assert!(responses[0].text.is_none());
+        assert!(responses[0].reasoning_content.is_none());
+    }
+}
diff --git a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs
index 496eb3a4..c266edbd 100644
--- a/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs
+++ b/src/crates/core/src/infrastructure/ai/ai_stream_handlers/src/types/mod.rs
@@ -2,3 +2,4 @@ pub mod unified;
 pub mod openai;
 pub mod anthropic;
 pub mod responses;
+pub mod gemini;
diff --git a/src/crates/core/src/infrastructure/ai/client.rs b/src/crates/core/src/infrastructure/ai/client.rs
index f3ae7508..3c9b2709 100644
--- a/src/crates/core/src/infrastructure/ai/client.rs
+++ b/src/crates/core/src/infrastructure/ai/client.rs
@@ -3,12 +3,13 @@
 //! Uses a modular architecture to separate provider-specific logic into the providers module
 
 use crate::infrastructure::ai::providers::anthropic::AnthropicMessageConverter;
+use crate::infrastructure::ai::providers::gemini::GeminiMessageConverter;
 use crate::infrastructure::ai::providers::openai::OpenAIMessageConverter;
 use crate::service::config::ProxyConfig;
 use crate::util::types::*;
 use crate::util::JsonChecker;
 use ai_stream_handlers::{
-    handle_anthropic_stream, handle_openai_stream, handle_responses_stream, UnifiedResponse,
+    handle_anthropic_stream, handle_gemini_stream, handle_openai_stream, handle_responses_stream, UnifiedResponse,
 };
 use anyhow::{anyhow, Result};
 use futures::StreamExt;
 use log::{debug, error, info, warn};
@@ -101,6 +102,10 @@ impl AIClient {
         matches!(api_format.to_ascii_lowercase().as_str(), "response" | "responses")
     }
 
+    fn is_gemini_api_format(api_format: &str) -> bool {
+        matches!(api_format.to_ascii_lowercase().as_str(), "gemini" | "google")
+    }
+
     /// Create an AIClient without proxy (backward compatible)
     pub fn new(config: AIConfig) -> Self {
         let skip_ssl_verify = config.skip_ssl_verify;
@@ -374,6 +379,33 @@ impl AIClient {
         builder
     }
 
+    /// Apply Gemini-style request headers (merge/replace).
+    fn apply_gemini_headers(
+        &self,
+        mut builder: reqwest::RequestBuilder,
+    ) -> reqwest::RequestBuilder {
+        let has_custom_headers = self
+            .config
+            .custom_headers
+            .as_ref()
+            .map_or(false, |h| !h.is_empty());
+        let is_merge_mode = self.is_merge_headers_mode();
+
+        if has_custom_headers && !is_merge_mode {
+            return self.apply_custom_headers(builder);
+        }
+
+        builder = builder
+            .header("Content-Type", "application/json")
+            .header("x-goog-api-key", &self.config.api_key);
+
+        if has_custom_headers && is_merge_mode {
+            builder = self.apply_custom_headers(builder);
+        }
+
+        builder
+    }
+
     /// Build an OpenAI-format request body
     fn build_openai_request_body(
         &self,
@@ -571,6 +603,116 @@ impl AIClient {
         request_body
     }
 
+    /// Build a Gemini-format request body.
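+    ///
+    /// A sketch of the payload this produces (values illustrative; tools,
+    /// `toolConfig` and `extra_body` overrides are applied on top when present):
+    ///
+    /// ```json
+    /// {
+    ///   "contents": [{ "role": "user", "parts": [{ "text": "..." }] }],
+    ///   "systemInstruction": { "parts": [{ "text": "<joined system messages>" }] },
+    ///   "generationConfig": { "maxOutputTokens": 1024, "thinkingConfig": { "includeThoughts": true } }
+    /// }
+    /// ```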
+    fn build_gemini_request_body(
+        &self,
+        system_instruction: Option<serde_json::Value>,
+        contents: Vec<serde_json::Value>,
+        gemini_tools: Option<Vec<serde_json::Value>>,
+        extra_body: Option<serde_json::Value>,
+    ) -> serde_json::Value {
+        let mut request_body = serde_json::json!({
+            "contents": contents,
+        });
+
+        if let Some(system_instruction) = system_instruction {
+            request_body["systemInstruction"] = system_instruction;
+        }
+
+        if let Some(max_tokens) = self.config.max_tokens {
+            request_body["generationConfig"] = serde_json::json!({
+                "maxOutputTokens": max_tokens,
+            });
+        }
+
+        if self.config.enable_thinking_process {
+            if request_body.get("generationConfig").is_none() {
+                request_body["generationConfig"] = serde_json::json!({});
+            }
+            request_body["generationConfig"]["thinkingConfig"] = serde_json::json!({
+                "includeThoughts": true,
+            });
+        }
+
+        if let Some(tools) = gemini_tools {
+            let tool_names = tools
+                .iter()
+                .flat_map(|tool| {
+                    tool.get("functionDeclarations")
+                        .and_then(|value| value.as_array())
+                        .into_iter()
+                        .flatten()
+                        .filter_map(|declaration| {
+                            declaration
+                                .get("name")
+                                .and_then(|value| value.as_str())
+                                .map(str::to_string)
+                        })
+                })
+                .collect::<Vec<_>>();
+            debug!(target: "ai::gemini_stream_request", "\ntools: {:?}", tool_names);
+
+            if !tools.is_empty() {
+                request_body["tools"] = serde_json::Value::Array(tools);
+                request_body["toolConfig"] = serde_json::json!({
+                    "functionCallingConfig": {
+                        "mode": "AUTO"
+                    }
+                });
+            }
+        }
+
+        if let Some(extra) = extra_body {
+            if let Some(extra_obj) = extra.as_object() {
+                for (key, value) in extra_obj {
+                    request_body[key] = value.clone();
+                }
+                debug!(
+                    target: "ai::gemini_stream_request",
+                    "Applied extra_body overrides: {:?}",
+                    extra_obj.keys().collect::<Vec<_>>()
+                );
+            }
+        }
+
+        debug!(
+            target: "ai::gemini_stream_request",
+            "Gemini stream request body:\n{}",
+            serde_json::to_string_pretty(&request_body)
+                .unwrap_or_else(|_| "serialization failed".to_string())
+        );
+
+        request_body
+    }
+
+    fn resolve_gemini_request_url(base_url: &str, model_name: &str) -> String {
+        let trimmed = base_url.trim().trim_end_matches('/');
+        if trimmed.is_empty() {
+            return String::new();
+        }
+
+        let mut url = trimmed
+            .replace(":generateContent", ":streamGenerateContent")
+            .replace(":streamGenerateContent?alt=sse", ":streamGenerateContent");
+
+        if !url.contains(":streamGenerateContent") {
+            if url.contains("/models/") {
+                url = format!("{}:streamGenerateContent", url);
+            } else {
+                let encoded_model = urlencoding::encode(model_name);
+                url = format!("{}/models/{}:streamGenerateContent", url, encoded_model);
+            }
+        }
+
+        if url.contains("alt=sse") {
+            url
+        } else if url.contains('?') {
+            format!("{}&alt=sse", url)
+        } else {
+            format!("{}?alt=sse", url)
+        }
+    }
+
     fn extract_openai_tool_name(tool: &serde_json::Value) -> String {
         tool.get("function")
             .and_then(|f| f.get("name"))
@@ -618,6 +760,10 @@ impl AIClient {
                 self.send_openai_stream(messages, tools, extra_body, max_tries)
                     .await
             }
+            format if Self::is_gemini_api_format(format) => {
+                self.send_gemini_stream(messages, tools, extra_body, max_tries)
+                    .await
+            }
             format if Self::is_responses_api_format(format) => {
                 self.send_responses_stream(messages, tools, extra_body, max_tries)
                     .await
@@ -763,6 +909,137 @@ impl AIClient {
         Err(anyhow!(error_msg))
     }
 
+    /// Send a Gemini streaming request with retries.
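+    ///
+    /// The URL is normalized by `resolve_gemini_request_url` to a
+    /// `:streamGenerateContent` endpoint with `alt=sse`; retries mirror the
+    /// other providers (500ms base, exponential backoff, no retry on 4xx).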
+    async fn send_gemini_stream(
+        &self,
+        messages: Vec<Message>,
+        tools: Option<Vec<ToolDefinition>>,
+        extra_body: Option<serde_json::Value>,
+        max_tries: usize,
+    ) -> Result<StreamResponse> {
+        let url = Self::resolve_gemini_request_url(&self.config.request_url, &self.config.model);
+        debug!(
+            "Gemini config: model={}, request_url={}, max_tries={}",
+            self.config.model, url, max_tries
+        );
+
+        let (system_instruction, contents) =
+            GeminiMessageConverter::convert_messages(messages, &self.config.model);
+        let gemini_tools = GeminiMessageConverter::convert_tools(tools);
+        let request_body =
+            self.build_gemini_request_body(system_instruction, contents, gemini_tools, extra_body);
+
+        let mut last_error = None;
+        let base_wait_time_ms = 500;
+
+        for attempt in 0..max_tries {
+            let request_start_time = std::time::Instant::now();
+            let request_builder = self.apply_gemini_headers(self.client.post(&url));
+            let response_result = request_builder.json(&request_body).send().await;
+
+            let response = match response_result {
+                Ok(resp) => {
+                    let connect_time = request_start_time.elapsed().as_millis();
+                    let status = resp.status();
+
+                    if status.is_client_error() {
+                        let error_text = resp
+                            .text()
+                            .await
+                            .unwrap_or_else(|e| format!("Failed to read error response: {}", e));
+                        error!(
+                            "Gemini Streaming API client error {}: {}",
+                            status, error_text
+                        );
+                        return Err(anyhow!(
+                            "Gemini Streaming API client error {}: {}",
+                            status,
+                            error_text
+                        ));
+                    }
+
+                    if status.is_success() {
+                        debug!(
+                            "Gemini stream request connected: {}ms, status: {}, attempt: {}/{}",
+                            connect_time,
+                            status,
+                            attempt + 1,
+                            max_tries
+                        );
+                        resp
+                    } else {
+                        let error_text = resp
+                            .text()
+                            .await
+                            .unwrap_or_else(|e| format!("Failed to read error response: {}", e));
+                        let error =
+                            anyhow!("Gemini Streaming API error {}: {}", status, error_text);
+                        warn!(
+                            "Gemini stream request failed: {}ms, attempt {}/{}, error: {}",
+                            connect_time,
+                            attempt + 1,
+                            max_tries,
+                            error
+                        );
+                        last_error = Some(error);
+
+                        if attempt < max_tries - 1 {
+                            let delay_ms = base_wait_time_ms * (1 << attempt.min(3));
+                            debug!(
+                                "Retrying Gemini after {}ms (attempt {})",
+                                delay_ms,
+                                attempt + 2
+                            );
+                            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
+                        }
+                        continue;
+                    }
+                }
+                Err(e) => {
+                    let connect_time = request_start_time.elapsed().as_millis();
+                    let error = anyhow!("Gemini stream request connection failed: {}", e);
+                    warn!(
+                        "Gemini stream request connection failed: {}ms, attempt {}/{}, error: {}",
+                        connect_time,
+                        attempt + 1,
+                        max_tries,
+                        e
+                    );
+                    last_error = Some(error);
+
+                    if attempt < max_tries - 1 {
+                        let delay_ms = base_wait_time_ms * (1 << attempt.min(3));
+                        debug!(
+                            "Retrying Gemini after {}ms (attempt {})",
+                            delay_ms,
+                            attempt + 2
+                        );
+                        tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
+                    }
+                    continue;
+                }
+            };
+
+            let (tx, rx) = mpsc::unbounded_channel();
+            let (tx_raw, rx_raw) = mpsc::unbounded_channel();
+
+            tokio::spawn(handle_gemini_stream(response, tx, Some(tx_raw)));
+
+            return Ok(StreamResponse {
+                stream: Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)),
+                raw_sse_rx: Some(rx_raw),
+            });
+        }
+
+        let error_msg = format!(
+            "Gemini stream request failed after {} attempts: {}",
+            max_tries,
+            last_error.unwrap_or_else(|| anyhow!("Unknown error"))
+        );
+        error!("{}", error_msg);
+        Err(anyhow!(error_msg))
+    }
+
+    /// Send a Responses API streaming request with retries.
    async fn send_responses_stream(
        &self,
diff --git a/src/crates/core/src/infrastructure/ai/providers/gemini/message_converter.rs b/src/crates/core/src/infrastructure/ai/providers/gemini/message_converter.rs
new file mode 100644
index 00000000..4c71742c
--- /dev/null
+++ b/src/crates/core/src/infrastructure/ai/providers/gemini/message_converter.rs
@@ -0,0 +1,669 @@
+//! Gemini message format converter
+
+use crate::util::types::{Message, ToolDefinition};
+use log::warn;
+use serde_json::{json, Map, Value};
+
+pub struct GeminiMessageConverter;
+
+impl GeminiMessageConverter {
+    pub fn convert_messages(messages: Vec<Message>, model_name: &str) -> (Option<Value>, Vec<Value>) {
+        let mut system_texts = Vec::new();
+        let mut contents = Vec::new();
+        let is_gemini_3 = model_name.contains("gemini-3");
+
+        for msg in messages {
+            match msg.role.as_str() {
+                "system" => {
+                    if let Some(content) = msg.content.filter(|content| !content.trim().is_empty())
+                    {
+                        system_texts.push(content);
+                    }
+                }
+                "user" => {
+                    let parts = Self::convert_content_parts(msg.content.as_deref(), false);
+                    Self::push_content(&mut contents, "user", parts);
+                }
+                "assistant" => {
+                    let mut parts = Vec::new();
+
+                    let mut pending_thought_signature = msg
+                        .thinking_signature
+                        .filter(|value| !value.trim().is_empty());
+                    let has_tool_calls = msg
+                        .tool_calls
+                        .as_ref()
+                        .map(|tool_calls| !tool_calls.is_empty())
+                        .unwrap_or(false);
+
+                    if let Some(content) = msg.content.as_deref().filter(|value| !value.trim().is_empty()) {
+                        if !has_tool_calls {
+                            if let Some(signature) = pending_thought_signature.take() {
+                                parts.push(json!({
+                                    "thoughtSignature": signature,
+                                }));
+                            }
+                        }
+                        parts.extend(Self::convert_content_parts(Some(content), true));
+                    }
+
+                    if let Some(tool_calls) = msg.tool_calls {
+                        for (tool_call_index, tool_call) in tool_calls.into_iter().enumerate() {
+                            let mut part = Map::new();
+                            part.insert(
+                                "functionCall".to_string(),
+                                json!({
+                                    "name": tool_call.name,
+                                    "args": tool_call.arguments,
+                                }),
+                            );
+
+                            match pending_thought_signature.take() {
+                                Some(signature) => {
+                                    part.insert(
+                                        "thoughtSignature".to_string(),
+                                        Value::String(signature),
+                                    );
+                                }
+                                None if is_gemini_3 && tool_call_index == 0 => {
+                                    part.insert(
+                                        "thoughtSignature".to_string(),
+                                        Value::String(
+                                            "skip_thought_signature_validator".to_string(),
+                                        ),
+                                    );
+                                }
+                                None => {}
+                            }
+
+                            parts.push(Value::Object(part));
+                        }
+                    }
+
+                    if let Some(signature) = pending_thought_signature {
+                        parts.push(json!({
+                            "thoughtSignature": signature,
+                        }));
+                    }
+
+                    Self::push_content(&mut contents, "model", parts);
+                }
+                "tool" => {
+                    let tool_name = msg.name.unwrap_or_default();
+                    if tool_name.is_empty() {
+                        warn!("Skipping Gemini tool response without tool name");
+                        continue;
+                    }
+
+                    let response = Self::parse_tool_response(msg.content.as_deref());
+                    let parts = vec![json!({
+                        "functionResponse": {
+                            "name": tool_name,
+                            "response": response,
+                        }
+                    })];
+
+                    Self::push_content(&mut contents, "user", parts);
+                }
+                _ => {
+                    warn!("Unknown Gemini message role: {}", msg.role);
+                }
+            }
+        }
+
+        let system_instruction = if system_texts.is_empty() {
+            None
+        } else {
+            Some(json!({
+                "parts": [{
+                    "text": system_texts.join("\n\n")
+                }]
+            }))
+        };
+
+        (system_instruction, contents)
+    }
+
+    pub fn convert_tools(tools: Option<Vec<ToolDefinition>>) -> Option<Vec<Value>> {
+        tools.and_then(|tool_defs| {
+            let declarations: Vec<Value> = tool_defs
+                .into_iter()
+                .map(|tool| {
+                    let parameters = Self::strip_unsupported_schema_fields(tool.parameters);
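+                    // Gemini's functionDeclarations accept only a JSON Schema
+                    // subset, so unsupported keywords (e.g. $schema,
+                    // additionalProperties) are stripped and allOf/anyOf/oneOf
+                    // variants are merged before conversion.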
+                    json!({
+                        "name": tool.name,
+                        "description": tool.description,
+                        "parameters": parameters,
+                    })
+                })
+                .collect();
+
+            if declarations.is_empty() {
+                None
+            } else {
+                Some(vec![json!({
+                    "functionDeclarations": declarations,
+                })])
+            }
+        })
+    }
+
+    fn push_content(contents: &mut Vec<Value>, role: &str, parts: Vec<Value>) {
+        if parts.is_empty() {
+            return;
+        }
+
+        if let Some(last) = contents.last_mut() {
+            let last_role = last.get("role").and_then(Value::as_str).unwrap_or_default();
+            if last_role == role {
+                if let Some(existing_parts) = last.get_mut("parts").and_then(Value::as_array_mut) {
+                    existing_parts.extend(parts);
+                    return;
+                }
+            }
+        }
+
+        contents.push(json!({
+            "role": role,
+            "parts": parts,
+        }));
+    }
+
+    fn convert_content_parts(content: Option<&str>, is_model_role: bool) -> Vec<Value> {
+        let Some(content) = content else {
+            return Vec::new();
+        };
+
+        if content.trim().is_empty() {
+            return Vec::new();
+        }
+
+        let parsed = match serde_json::from_str::<Value>(content) {
+            Ok(parsed) if parsed.is_array() => parsed,
+            _ => return vec![json!({ "text": content })],
+        };
+
+        let mut parts = Vec::new();
+
+        if let Some(items) = parsed.as_array() {
+            for item in items {
+                let item_type = item.get("type").and_then(Value::as_str);
+                match item_type {
+                    Some("text") | Some("input_text") | Some("output_text") => {
+                        if let Some(text) = item.get("text").and_then(Value::as_str) {
+                            if !text.is_empty() {
+                                parts.push(json!({ "text": text }));
+                            }
+                        }
+                    }
+                    Some("image_url") if !is_model_role => {
+                        if let Some(url) = item.get("image_url").and_then(|value| {
+                            value
+                                .get("url")
+                                .and_then(Value::as_str)
+                                .or_else(|| value.as_str())
+                        }) {
+                            if let Some(part) = Self::convert_image_url_to_part(url) {
+                                parts.push(part);
+                            }
+                        }
+                    }
+                    Some("image") if !is_model_role => {
+                        let source = item.get("source");
+                        let mime_type = source
+                            .and_then(|value| value.get("media_type"))
+                            .and_then(Value::as_str);
+                        let data = source
+                            .and_then(|value| value.get("data"))
+                            .and_then(Value::as_str);
+
+                        if let (Some(mime_type), Some(data)) = (mime_type, data) {
+                            parts.push(json!({
+                                "inlineData": {
+                                    "mimeType": mime_type,
+                                    "data": data,
+                                }
+                            }));
+                        }
+                    }
+                    _ => {}
+                }
+            }
+        }
+
+        if parts.is_empty() {
+            vec![json!({ "text": content })]
+        } else {
+            parts
+        }
+    }
+
+    fn convert_image_url_to_part(url: &str) -> Option<Value> {
+        let prefix = "data:";
+        if !url.starts_with(prefix) {
+            warn!("Gemini currently supports inline data URLs for image parts; skipping unsupported image URL");
+            return None;
+        }
+
+        let rest = &url[prefix.len()..];
+        let (mime_type, data) = rest.split_once(";base64,")?;
+        if mime_type.is_empty() || data.is_empty() {
+            return None;
+        }
+
+        Some(json!({
+            "inlineData": {
+                "mimeType": mime_type,
+                "data": data,
+            }
+        }))
+    }
+
+    fn parse_tool_response(content: Option<&str>) -> Value {
+        let Some(content) = content.filter(|value| !value.trim().is_empty()) else {
+            return json!({ "content": "Tool execution completed" });
+        };
+
+        match serde_json::from_str::<Value>(content) {
+            Ok(Value::Object(map)) => Value::Object(map),
+            Ok(value) => json!({ "content": value }),
+            Err(_) => json!({ "content": content }),
+        }
+    }
+
+    fn strip_unsupported_schema_fields(value: Value) -> Value {
+        match value {
+            Value::Object(mut map) => {
+                let all_of = map.remove("allOf");
+                let any_of = map.remove("anyOf");
+                let one_of = map.remove("oneOf");
+                let (normalized_type, nullable_from_type) =
+                    Self::normalize_schema_type(map.remove("type"));
+
+                let mut sanitized = Map::new();
+                for (key, value) in map {
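+                    // Recurse into "properties" so nested object schemas are
+                    // sanitized with the same rules.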
+                    if key == "properties" {
+                        if let Value::Object(properties) = value {
+                            sanitized.insert(
+                                key,
+                                Value::Object(
+                                    properties
+                                        .into_iter()
+                                        .map(|(name, schema)| {
+                                            (name, Self::strip_unsupported_schema_fields(schema))
+                                        })
+                                        .collect(),
+                                ),
+                            );
+                        }
+                        continue;
+                    }
+
+                    if Self::is_supported_schema_key(&key) {
+                        sanitized.insert(key, Self::strip_unsupported_schema_fields(value));
+                    }
+                }
+
+                if let Some(all_of) = all_of {
+                    Self::merge_schema_variants(&mut sanitized, all_of, true);
+                }
+
+                let mut nullable = nullable_from_type;
+                if let Some(any_of) = any_of {
+                    nullable |= Self::merge_union_variants(&mut sanitized, any_of);
+                }
+                if let Some(one_of) = one_of {
+                    nullable |= Self::merge_union_variants(&mut sanitized, one_of);
+                }
+
+                if let Some(schema_type) = normalized_type {
+                    sanitized.insert("type".to_string(), Value::String(schema_type));
+                }
+                if nullable {
+                    sanitized.insert("nullable".to_string(), Value::Bool(true));
+                }
+
+                Value::Object(sanitized)
+            }
+            Value::Array(items) => Value::Array(
+                items
+                    .into_iter()
+                    .map(Self::strip_unsupported_schema_fields)
+                    .collect(),
+            ),
+            other => other,
+        }
+    }
+
+    fn is_supported_schema_key(key: &str) -> bool {
+        matches!(
+            key,
+            "type"
+                | "format"
+                | "description"
+                | "nullable"
+                | "enum"
+                | "items"
+                | "properties"
+                | "required"
+                | "minItems"
+                | "maxItems"
+                | "minimum"
+                | "maximum"
+                | "minLength"
+                | "maxLength"
+                | "pattern"
+        )
+    }
+
+    fn normalize_schema_type(type_value: Option<Value>) -> (Option<String>, bool) {
+        match type_value {
+            Some(Value::String(value)) if value != "null" => (Some(value), false),
+            Some(Value::String(_)) => (None, true),
+            Some(Value::Array(values)) => {
+                let mut types = values
+                    .into_iter()
+                    .filter_map(|value| value.as_str().map(str::to_string));
+                let mut nullable = false;
+                let mut selected = None;
+
+                for value in types.by_ref() {
+                    if value == "null" {
+                        nullable = true;
+                    } else if selected.is_none() {
+                        selected = Some(value);
+                    }
+                }
+
+                (selected, nullable)
+            }
+            _ => (None, false),
+        }
+    }
+
+    fn merge_union_variants(target: &mut Map<String, Value>, variants: Value) -> bool {
+        let mut nullable = false;
+
+        if let Value::Array(variants) = variants {
+            for variant in variants {
+                let sanitized = Self::strip_unsupported_schema_fields(variant);
+                match sanitized {
+                    Value::Object(map) => {
+                        let is_null_only = map
+                            .get("type")
+                            .and_then(Value::as_str)
+                            .map(|value| value == "null")
+                            .unwrap_or(false)
+                            && map.len() == 1;
+
+                        if is_null_only {
+                            nullable = true;
+                            continue;
+                        }
+
+                        Self::merge_schema_map(target, map, false);
+                    }
+                    Value::String(value) if value == "null" => nullable = true,
+                    _ => {}
+                }
+            }
+        }
+
+        nullable
+    }
+
+    fn merge_schema_variants(target: &mut Map<String, Value>, variants: Value, preserve_required: bool) {
+        if let Value::Array(variants) = variants {
+            for variant in variants {
+                if let Value::Object(map) = Self::strip_unsupported_schema_fields(variant) {
+                    Self::merge_schema_map(target, map, preserve_required);
+                }
+            }
+        }
+    }
+
+    fn merge_schema_map(
+        target: &mut Map<String, Value>,
+        source: Map<String, Value>,
+        preserve_required: bool,
+    ) {
+        for (key, value) in source {
+            match key.as_str() {
+                "properties" => {
+                    if let Value::Object(source_props) = value {
+                        let target_props = target
+                            .entry(key)
+                            .or_insert_with(|| Value::Object(Map::new()));
+                        if let Value::Object(target_props) = target_props {
+                            for (prop_key, prop_value) in source_props {
+                                target_props.entry(prop_key).or_insert(prop_value);
+                            }
+                        }
+                    }
+                }
+                "required" if preserve_required => {
+                    if let Value::Array(source_required) = value {
+                        let target_required = target
+                            .entry(key)
+                            .or_insert_with(|| Value::Array(Vec::new()));
+                        if let
Value::Array(target_required) = target_required { + for item in source_required { + if !target_required.contains(&item) { + target_required.push(item); + } + } + } + } + } + "nullable" => { + if value.as_bool().unwrap_or(false) { + target.insert(key, Value::Bool(true)); + } + } + "type" => { + target.entry(key).or_insert(value); + } + _ => { + target.entry(key).or_insert(value); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::GeminiMessageConverter; + use crate::util::types::{Message, ToolCall, ToolDefinition}; + use serde_json::json; + use std::collections::HashMap; + + #[test] + fn converts_messages_to_gemini_format() { + let mut args = HashMap::new(); + args.insert("city".to_string(), json!("Beijing")); + + let messages = vec![ + Message::system("You are helpful".to_string()), + Message::user("Hello".to_string()), + Message { + role: "assistant".to_string(), + content: Some("Working on it".to_string()), + reasoning_content: Some("Let me think".to_string()), + thinking_signature: Some("sig_1".to_string()), + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: args.clone(), + }]), + tool_call_id: None, + name: None, + }, + Message { + role: "tool".to_string(), + content: Some("Sunny".to_string()), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: Some("call_1".to_string()), + name: Some("get_weather".to_string()), + }, + ]; + + let (system_instruction, contents) = + GeminiMessageConverter::convert_messages(messages, "gemini-2.5-pro"); + + assert_eq!( + system_instruction.unwrap()["parts"][0]["text"], + json!("You are helpful") + ); + assert_eq!(contents.len(), 3); + assert_eq!(contents[0]["role"], json!("user")); + assert_eq!(contents[1]["role"], json!("model")); + assert_eq!(contents[1]["parts"][0]["text"], json!("Working on it")); + assert_eq!( + contents[1]["parts"][1]["functionCall"]["name"], + json!("get_weather") + ); + assert_eq!(contents[1]["parts"][1]["thoughtSignature"], json!("sig_1")); + assert_eq!( + contents[2]["parts"][0]["functionResponse"]["name"], + json!("get_weather") + ); + } + + #[test] + fn injects_skip_signature_for_first_synthetic_gemini_3_tool_call() { + let mut args = HashMap::new(); + args.insert("city".to_string(), json!("Paris")); + + let messages = vec![Message { + role: "assistant".to_string(), + content: None, + reasoning_content: None, + thinking_signature: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: args, + }]), + tool_call_id: None, + name: None, + }]; + + let (_, contents) = + GeminiMessageConverter::convert_messages(messages, "gemini-3-flash-preview"); + + assert_eq!(contents.len(), 1); + assert_eq!( + contents[0]["parts"][0]["thoughtSignature"], + json!("skip_thought_signature_validator") + ); + } + + #[test] + fn converts_data_url_images_to_inline_data() { + let messages = vec![Message { + role: "user".to_string(), + content: Some( + json!([ + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,abc" + } + }, + { + "type": "text", + "text": "Describe this image" + } + ]) + .to_string(), + ), + reasoning_content: None, + thinking_signature: None, + tool_calls: None, + tool_call_id: None, + name: None, + }]; + + let (_, contents) = GeminiMessageConverter::convert_messages(messages, "gemini-2.5-pro"); + + assert_eq!( + contents[0]["parts"][0]["inlineData"]["mimeType"], + json!("image/png") + ); + assert_eq!( + contents[0]["parts"][1]["text"], + 
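+            // parts keep the original content order: inlineData first, then the text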
json!("Describe this image") + ); + } + + #[test] + fn strips_unsupported_fields_from_tool_schema() { + let tools = Some(vec![ToolDefinition { + name: "get_weather".to_string(), + description: "Get weather".to_string(), + parameters: json!({ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "city": { "type": "string" }, + "timezone": { + "type": ["string", "null"] + }, + "link": { + "anyOf": [ + { + "type": "object", + "properties": { + "url": { "type": "string" } + }, + "required": ["url"] + }, + { "type": "null" } + ] + }, + "items": { + "allOf": [ + { + "type": "object", + "properties": { + "name": { "type": "string" } + }, + "required": ["name"] + }, + { + "type": "object", + "properties": { + "count": { "type": "integer" } + }, + "required": ["count"] + } + ] + } + }, + "required": ["city"], + "additionalProperties": false, + "items": { + "type": "object", + "additionalProperties": false + } + }), + }]); + + let converted = GeminiMessageConverter::convert_tools(tools).expect("converted tools"); + let schema = &converted[0]["functionDeclarations"][0]["parameters"]; + + assert!(schema.get("$schema").is_none()); + assert!(schema.get("additionalProperties").is_none()); + assert!(schema["items"].get("additionalProperties").is_none()); + assert_eq!(schema["properties"]["timezone"]["type"], json!("string")); + assert_eq!(schema["properties"]["timezone"]["nullable"], json!(true)); + assert_eq!(schema["properties"]["link"]["type"], json!("object")); + assert_eq!(schema["properties"]["link"]["nullable"], json!(true)); + assert_eq!(schema["properties"]["items"]["type"], json!("object")); + assert_eq!( + schema["properties"]["items"]["required"], + json!(["name", "count"]) + ); + } +} diff --git a/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs b/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs new file mode 100644 index 00000000..ee6d89d2 --- /dev/null +++ b/src/crates/core/src/infrastructure/ai/providers/gemini/mod.rs @@ -0,0 +1,5 @@ +//! 
Gemini provider module + +pub mod message_converter; + +pub use message_converter::GeminiMessageConverter; diff --git a/src/crates/core/src/infrastructure/ai/providers/mod.rs b/src/crates/core/src/infrastructure/ai/providers/mod.rs index 61ce45c6..d0e806ae 100644 --- a/src/crates/core/src/infrastructure/ai/providers/mod.rs +++ b/src/crates/core/src/infrastructure/ai/providers/mod.rs @@ -4,6 +4,7 @@ pub mod openai; pub mod anthropic; +pub mod gemini; pub use anthropic::AnthropicMessageConverter; - +pub use gemini::GeminiMessageConverter; diff --git a/src/crates/core/src/util/types/config.rs b/src/crates/core/src/util/types/config.rs index afd0649d..8ef7873d 100644 --- a/src/crates/core/src/util/types/config.rs +++ b/src/crates/core/src/util/types/config.rs @@ -13,7 +13,42 @@ fn append_endpoint(base_url: &str, endpoint: &str) -> String { format!("{}/{}", base.trim_end_matches('/'), endpoint) } -fn resolve_request_url(base_url: &str, provider: &str) -> String { +fn resolve_gemini_request_url(base_url: &str, model_name: &str) -> String { + let trimmed = base_url.trim().trim_end_matches('/').to_string(); + if trimmed.is_empty() { + return String::new(); + } + + if let Some(stripped) = trimmed.strip_suffix('#') { + return stripped.trim_end_matches('/').to_string(); + } + + let stream_endpoint = ":streamGenerateContent?alt=sse"; + if trimmed.contains(":generateContent") { + return trimmed.replace(":generateContent", stream_endpoint); + } + if trimmed.contains(":streamGenerateContent") { + if trimmed.contains("alt=sse") { + return trimmed; + } + if trimmed.contains('?') { + return format!("{}&alt=sse", trimmed); + } + return format!("{}?alt=sse", trimmed); + } + if trimmed.contains("/models/") { + return format!("{}{}", trimmed, stream_endpoint); + } + + let model = model_name.trim(); + if model.is_empty() { + return trimmed; + } + + append_endpoint(&trimmed, &format!("models/{}{}", model, stream_endpoint)) +} + +fn resolve_request_url(base_url: &str, provider: &str, model_name: &str) -> String { let trimmed = base_url.trim().trim_end_matches('/').to_string(); if trimmed.is_empty() { return String::new(); @@ -27,6 +62,7 @@ fn resolve_request_url(base_url: &str, provider: &str) -> String { "openai" => append_endpoint(&trimmed, "chat/completions"), "response" | "responses" => append_endpoint(&trimmed, "responses"), "anthropic" => append_endpoint(&trimmed, "v1/messages"), + "gemini" | "google" => resolve_gemini_request_url(&trimmed, model_name), _ => trimmed, } } @@ -61,7 +97,7 @@ mod tests { #[test] fn resolves_openai_request_url() { assert_eq!( - resolve_request_url("https://api.openai.com/v1", "openai"), + resolve_request_url("https://api.openai.com/v1", "openai", ""), "https://api.openai.com/v1/chat/completions" ); } @@ -69,7 +105,7 @@ mod tests { #[test] fn resolves_responses_request_url() { assert_eq!( - resolve_request_url("https://api.openai.com/v1", "responses"), + resolve_request_url("https://api.openai.com/v1", "responses", ""), "https://api.openai.com/v1/responses" ); } @@ -77,7 +113,7 @@ mod tests { #[test] fn resolves_response_alias_request_url() { assert_eq!( - resolve_request_url("https://api.openai.com/v1", "response"), + resolve_request_url("https://api.openai.com/v1", "response", ""), "https://api.openai.com/v1/responses" ); } @@ -85,10 +121,22 @@ mod tests { #[test] fn keeps_forced_request_url() { assert_eq!( - resolve_request_url("https://api.openai.com/v1/responses#", "responses"), + resolve_request_url("https://api.openai.com/v1/responses#", "responses", ""), 
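+            // a trailing '#' pins the base URL: it is stripped and no endpoint is appended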
"https://api.openai.com/v1/responses" ); } + + #[test] + fn resolves_gemini_request_url() { + assert_eq!( + resolve_request_url( + "https://generativelanguage.googleapis.com/v1beta", + "gemini", + "gemini-2.5-pro" + ), + "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse" + ); + } } impl TryFrom for AIConfig { @@ -111,7 +159,7 @@ impl TryFrom for AIConfig { let request_url = other .request_url .filter(|u| !u.is_empty()) - .unwrap_or_else(|| resolve_request_url(&other.base_url, &other.provider)); + .unwrap_or_else(|| resolve_request_url(&other.base_url, &other.provider, &other.model_name)); Ok(AIConfig { name: other.name.clone(), diff --git a/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx b/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx index eee62346..132cbe0f 100644 --- a/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx +++ b/src/web-ui/src/features/onboarding/components/steps/ModelConfigStep.tsx @@ -19,7 +19,7 @@ interface ModelConfigStepProps { } /** Provider display order */ -const PROVIDER_ORDER = ['zhipu', 'qwen', 'deepseek', 'volcengine', 'minimax', 'moonshot', 'anthropic']; +const PROVIDER_ORDER = ['zhipu', 'qwen', 'deepseek', 'volcengine', 'minimax', 'moonshot', 'gemini', 'anthropic']; type TestStatus = 'idle' | 'testing' | 'success' | 'error'; @@ -33,8 +33,8 @@ export const ModelConfigStep: React.FC = ({ onSkipForNow } const [apiKey, setApiKey] = useState(modelConfig?.apiKey || ''); const [baseUrl, setBaseUrl] = useState(modelConfig?.baseUrl || ''); const [modelName, setModelName] = useState(modelConfig?.modelName || ''); - const [customFormat, setCustomFormat] = useState<'openai' | 'responses' | 'anthropic'>( - (modelConfig?.format as 'openai' | 'responses' | 'anthropic') || 'openai' + const [customFormat, setCustomFormat] = useState<'openai' | 'responses' | 'anthropic' | 'gemini'>( + (modelConfig?.format as 'openai' | 'responses' | 'anthropic' | 'gemini') || 'openai' ); const [testStatus, setTestStatus] = useState('idle'); const [testError, setTestError] = useState(''); @@ -120,7 +120,7 @@ export const ModelConfigStep: React.FC = ({ onSkipForNow } const effectiveModelName = modelName || (template?.models[0] || ''); // Derive format - let format: 'openai' | 'responses' | 'anthropic' = customFormat; + let format: 'openai' | 'responses' | 'anthropic' | 'gemini' = customFormat; if (template) { if (template.baseUrlOptions?.length) { const effectiveUrl = baseUrl || template.baseUrl; @@ -502,10 +502,11 @@ export const ModelConfigStep: React.FC = ({ onSkipForNow } options={[ { label: 'OpenAI', value: 'openai' }, { label: 'OpenAI Responses', value: 'responses' }, - { label: 'Anthropic', value: 'anthropic' } + { label: 'Anthropic', value: 'anthropic' }, + { label: 'Gemini', value: 'gemini' } ]} value={customFormat} - onChange={(val) => setCustomFormat(val as 'openai' | 'responses' | 'anthropic')} + onChange={(val) => setCustomFormat(val as 'openai' | 'responses' | 'anthropic' | 'gemini')} placeholder={t('model.format.placeholder')} /> diff --git a/src/web-ui/src/features/onboarding/store/onboardingStore.ts b/src/web-ui/src/features/onboarding/store/onboardingStore.ts index fcaf6988..1451b378 100644 --- a/src/web-ui/src/features/onboarding/store/onboardingStore.ts +++ b/src/web-ui/src/features/onboarding/store/onboardingStore.ts @@ -37,7 +37,7 @@ export interface OnboardingModelConfig { modelName?: string; testPassed?: boolean; // Fields needed for saving the model 
config on completion - format?: 'openai' | 'responses' | 'anthropic'; + format?: 'openai' | 'responses' | 'anthropic' | 'gemini'; configName?: string; customRequestBody?: string; skipSslVerify?: boolean; diff --git a/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx b/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx index b896adb8..4c063176 100644 --- a/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx +++ b/src/web-ui/src/infrastructure/config/components/AIModelConfig.tsx @@ -27,9 +27,10 @@ const log = createLogger('AIModelConfig'); * - openai → append '/chat/completions' unless already present * - responses → append '/responses' unless already present * - anthropic → append '/v1/messages' unless already present + * - gemini → append '/models/{model}:streamGenerateContent?alt=sse' * - other → use base_url as-is */ -function resolveRequestUrl(baseUrl: string, provider: string): string { +function resolveRequestUrl(baseUrl: string, provider: string, modelName = ''): string { const trimmed = baseUrl.trim().replace(/\/+$/, ''); if (trimmed.endsWith('#')) { return trimmed.slice(0, -1).replace(/\/+$/, ''); @@ -43,6 +44,19 @@ function resolveRequestUrl(baseUrl: string, provider: string): string { if (provider === 'anthropic') { return trimmed.endsWith('v1/messages') ? trimmed : `${trimmed}/v1/messages`; } + if (provider === 'gemini') { + if (!modelName.trim()) return trimmed; + if (trimmed.includes(':generateContent')) { + return trimmed.replace(':generateContent', ':streamGenerateContent?alt=sse'); + } + if (trimmed.includes(':streamGenerateContent')) { + return trimmed.includes('alt=sse') ? trimmed : `${trimmed}${trimmed.includes('?') ? '&' : '?'}alt=sse`; + } + if (trimmed.includes('/models/')) { + return `${trimmed}:streamGenerateContent?alt=sse`; + } + return `${trimmed}/models/${modelName}:streamGenerateContent?alt=sse`; + } return trimmed; } @@ -78,6 +92,7 @@ const AIModelConfig: React.FC = () => { { label: 'OpenAI (chat/completions)', value: 'openai' }, { label: 'OpenAI (responses)', value: 'responses' }, { label: 'Anthropic (messages)', value: 'anthropic' }, + { label: 'Gemini (generateContent)', value: 'gemini' }, ], [] ); @@ -101,7 +116,7 @@ const AIModelConfig: React.FC = () => { }; // Provider options with translations (must be at top level, before any conditional returns) - const providerOrder = ['zhipu', 'qwen', 'deepseek', 'volcengine', 'minimax', 'moonshot', 'anthropic']; + const providerOrder = ['zhipu', 'qwen', 'deepseek', 'volcengine', 'minimax', 'moonshot', 'gemini', 'anthropic']; const providers = useMemo(() => { const sorted = Object.values(PROVIDER_TEMPLATES).sort((a, b) => { const indexA = providerOrder.indexOf(a.id); @@ -154,7 +169,7 @@ const AIModelConfig: React.FC = () => { setEditingConfig({ name: defaultModel ? 
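/* e.g. "Gemini - gemini-2.5-pro" (illustrative) */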
`${providerName} - ${defaultModel}` : '', base_url: template.baseUrl, - request_url: resolveRequestUrl(template.baseUrl, template.format), + request_url: resolveRequestUrl(template.baseUrl, template.format, defaultModel), api_key: '', model_name: defaultModel, provider: template.format, @@ -229,7 +244,7 @@ const AIModelConfig: React.FC = () => { id: editingConfig.id || `model_${Date.now()}`, name: editingConfig.name, base_url: editingConfig.base_url, - request_url: editingConfig.request_url || resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai'), + request_url: editingConfig.request_url || resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai', editingConfig.model_name || ''), api_key: editingConfig.api_key || '', model_name: editingConfig.model_name || 'search-api', provider: editingConfig.provider || 'openai', @@ -555,17 +570,17 @@ const AIModelConfig: React.FC = () => { case 'general_chat': defaultCapabilities = ['text_chat', 'function_calling']; updates.base_url = 'https://open.bigmodel.cn/api/paas/v4/chat/completions'; - updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai'); + updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai', prev?.model_name || ''); break; case 'multimodal': defaultCapabilities = ['text_chat', 'image_understanding', 'function_calling']; updates.base_url = 'https://open.bigmodel.cn/api/paas/v4/chat/completions'; - updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai'); + updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai', prev?.model_name || ''); break; case 'image_generation': defaultCapabilities = ['image_generation']; updates.base_url = 'https://open.bigmodel.cn/api/paas/v4/images/generations'; - updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai'); + updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai', prev?.model_name || ''); break; case 'search_enhanced': defaultCapabilities = ['search']; @@ -579,7 +594,7 @@ const AIModelConfig: React.FC = () => { case 'speech_recognition': defaultCapabilities = ['speech_recognition']; updates.base_url = 'https://open.bigmodel.cn/api/paas/v4/chat/completions'; - updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai'); + updates.request_url = resolveRequestUrl(updates.base_url!, prev?.provider || 'openai', prev?.model_name || ''); break; } updates.capabilities = defaultCapabilities; @@ -614,6 +629,7 @@ const AIModelConfig: React.FC = () => { return { ...prev, model_name: newModelName, + request_url: resolveRequestUrl(prev?.base_url || currentTemplate?.baseUrl || '', prev?.provider || currentTemplate?.format || 'openai', newModelName), name: isAutoGenerated && currentTemplate ? 
`${currentTemplate.name} - ${newModelName}` : prev?.name }; }); @@ -642,7 +658,7 @@ const AIModelConfig: React.FC = () => { setEditingConfig(prev => ({ ...prev, base_url: value as string, - request_url: resolveRequestUrl(value as string, newProvider), + request_url: resolveRequestUrl(value as string, newProvider, editingConfig.model_name || ''), provider: newProvider })); }} @@ -657,7 +673,7 @@ const AIModelConfig: React.FC = () => { onChange={(e) => setEditingConfig(prev => ({ ...prev, base_url: e.target.value, - request_url: resolveRequestUrl(e.target.value, prev?.provider || 'openai') + request_url: resolveRequestUrl(e.target.value, prev?.provider || 'openai', prev?.model_name || '') }))} onFocus={(e) => e.target.select()} placeholder={currentTemplate?.baseUrl} @@ -667,7 +683,7 @@ const AIModelConfig: React.FC = () => {
{t('form.resolvedUrlLabel')} - {resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai')} + {resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai', editingConfig.model_name || '')} {t('form.forceUrlHint')}
@@ -681,7 +697,7 @@ const AIModelConfig: React.FC = () => { onChange={(value) => setEditingConfig(prev => ({ ...prev, provider: value as string, - request_url: resolveRequestUrl(prev?.base_url || '', value as string) + request_url: resolveRequestUrl(prev?.base_url || '', value as string, prev?.model_name || '') }))} placeholder={t('form.providerPlaceholder')} options={requestFormatOptions} @@ -719,7 +735,7 @@ const AIModelConfig: React.FC = () => { onChange={(e) => setEditingConfig(prev => ({ ...prev, base_url: e.target.value, - request_url: resolveRequestUrl(e.target.value, prev?.provider || 'openai') + request_url: resolveRequestUrl(e.target.value, prev?.provider || 'openai', prev?.model_name || '') }))} onFocus={(e) => e.target.select()} placeholder={editingConfig.category === 'search_enhanced' ? 'https://open.bigmodel.cn/api/paas/v4/web_search' : 'https://open.bigmodel.cn/api/paas/v4/chat/completions'} @@ -729,7 +745,7 @@ const AIModelConfig: React.FC = () => {
{t('form.resolvedUrlLabel')} - {resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai')} + {resolveRequestUrl(editingConfig.base_url, editingConfig.provider || 'openai', editingConfig.model_name || '')} {t('form.forceUrlHint')}
@@ -745,7 +761,7 @@ const AIModelConfig: React.FC = () => { {!isFromTemplate && editingConfig.category !== 'search_enhanced' && ( <> - setEditingConfig(prev => ({ ...prev, model_name: e.target.value }))} placeholder={editingConfig.category === 'speech_recognition' ? 'glm-asr' : 'glm-4.7'} inputSize="small" /> + setEditingConfig(prev => ({ ...prev, model_name: e.target.value, request_url: resolveRequestUrl(prev?.base_url || '', prev?.provider || 'openai', e.target.value) }))} placeholder={editingConfig.category === 'speech_recognition' ? 'glm-asr' : 'glm-4.7'} inputSize="small" /> setEditingConfig(prev => ({ ...prev, reasoning_effort: (v as string) || undefined }))} placeholder={t('reasoningEffort.placeholder')} options={reasoningEffortOptions} /> + + )} ) : ( <> @@ -764,7 +785,15 @@ const AIModelConfig: React.FC = () => { setEditingConfig(prev => ({ ...prev, model_name: e.target.value, request_url: resolveRequestUrl(prev?.base_url || '', prev?.provider || 'openai', e.target.value) }))} placeholder={editingConfig.category === 'speech_recognition' ? 'glm-asr' : 'glm-4.7'} inputSize="small" />
- { + const provider = value as string; + setEditingConfig(prev => ({ + ...prev, + provider, + request_url: resolveRequestUrl(prev?.base_url || '', provider, prev?.model_name || ''), + reasoning_effort: isResponsesProvider(provider) ? (prev?.reasoning_effort || 'medium') : undefined, + })); + }} placeholder={t('form.providerPlaceholder')} options={requestFormatOptions} /> {editingConfig.category !== 'speech_recognition' && ( <> @@ -781,6 +810,11 @@ const AIModelConfig: React.FC = () => { setEditingConfig(prev => ({ ...prev, enable_thinking_process: e.target.checked }))} size="small" /> + {isResponsesProvider(editingConfig.provider) && ( + +