diff --git a/crates/openfang-kernel/src/kernel.rs b/crates/openfang-kernel/src/kernel.rs index d9fe60f971..2f292a0602 100644 --- a/crates/openfang-kernel/src/kernel.rs +++ b/crates/openfang-kernel/src/kernel.rs @@ -881,12 +881,14 @@ impl OpenFangKernel { // Auto-detect embedding provider by checking API key env vars in // priority order. First match wins. const API_KEY_PROVIDERS: &[(&str, &str)] = &[ - ("OPENAI_API_KEY", "openai"), - ("GROQ_API_KEY", "groq"), - ("MISTRAL_API_KEY", "mistral"), - ("TOGETHER_API_KEY", "together"), + ("GEMINI_API_KEY", "gemini"), + ("GOOGLE_API_KEY", "gemini"), + ("OPENAI_API_KEY", "openai"), + ("GROQ_API_KEY", "groq"), + ("MISTRAL_API_KEY", "mistral"), + ("TOGETHER_API_KEY", "together"), ("FIREWORKS_API_KEY", "fireworks"), - ("COHERE_API_KEY", "cohere"), + ("COHERE_API_KEY", "cohere"), ]; let detected_from_key = API_KEY_PROVIDERS @@ -1127,8 +1129,7 @@ impl OpenFangKernel { != entry.manifest.tool_allowlist || disk_manifest.tool_blocklist != entry.manifest.tool_blocklist - || disk_manifest.skills - != entry.manifest.skills + || disk_manifest.skills != entry.manifest.skills || disk_manifest.mcp_servers != entry.manifest.mcp_servers; if changed { @@ -5700,6 +5701,7 @@ fn apply_budget_defaults( fn default_embedding_model_for_provider(provider: &str) -> &'static str { match provider { "openai" => "text-embedding-3-small", + "gemini" | "google" => "gemini-embedding-2-preview", "groq" => "nomic-embed-text", "mistral" => "mistral-embed", "together" => "togethercomputer/m2-bert-80M-8k-retrieval", diff --git a/crates/openfang-runtime/src/embedding.rs b/crates/openfang-runtime/src/embedding.rs index c3245d879c..600ec6dc6c 100644 --- a/crates/openfang-runtime/src/embedding.rs +++ b/crates/openfang-runtime/src/embedding.rs @@ -6,10 +6,10 @@ use async_trait::async_trait; use openfang_types::model_catalog::{ - FIREWORKS_BASE_URL, GROQ_BASE_URL, LMSTUDIO_BASE_URL, MISTRAL_BASE_URL, OLLAMA_BASE_URL, - OPENAI_BASE_URL, TOGETHER_BASE_URL, VLLM_BASE_URL, 
+ FIREWORKS_BASE_URL, GEMINI_BASE_URL, GROQ_BASE_URL, LMSTUDIO_BASE_URL, MISTRAL_BASE_URL, + OLLAMA_BASE_URL, OPENAI_BASE_URL, TOGETHER_BASE_URL, VLLM_BASE_URL, }; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tracing::{debug, warn}; use zeroize::Zeroizing; @@ -70,6 +70,15 @@ pub struct OpenAIEmbeddingDriver { dims: usize, } +pub struct GeminiEmbeddingDriver { + api_key: Zeroizing<String>, + base_url: String, + resource_model: String, + client: reqwest::Client, + dims: usize, + mode: std::sync::atomic::AtomicU8, +} + #[derive(Serialize)] struct EmbedRequest<'a> { model: &'a str, @@ -86,6 +95,62 @@ struct EmbedData { embedding: Vec<f32>, } +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GeminiEmbedRequest<'a> { + requests: Vec<GeminiEmbedItem<'a>>, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GeminiEmbedItem<'a> { + model: &'a str, + content: GeminiEmbedContent<'a>, + output_dimensionality: usize, +} + +#[derive(Serialize)] +struct GeminiEmbedContent<'a> { + parts: Vec<GeminiEmbedPart<'a>>, +} + +#[derive(Serialize)] +struct GeminiEmbedPart<'a> { + text: &'a str, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct GeminiEmbedResponse { + embeddings: Vec<GeminiEmbedValue>, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct GeminiEmbedValue { + values: Vec<f32>, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GeminiSingleEmbedRequest<'a> { + model: &'a str, + content: GeminiEmbedContent<'a>, + output_dimensionality: usize, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct GeminiSingleEmbedResponse { + embedding: GeminiEmbedValue, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum GeminiEmbeddingMode { + Batch = 0, + Single = 1, +} + impl OpenAIEmbeddingDriver { /// Create a new OpenAI-compatible embedding driver. 
pub fn new(config: EmbeddingConfig) -> Result<Self, EmbeddingError> { @@ -102,6 +167,31 @@ impl OpenAIEmbeddingDriver { } } +fn normalize_gemini_model_name(model: &str) -> String { + if model.starts_with("models/") { + model.to_string() + } else { + format!("models/{model}") + } +} + +impl GeminiEmbeddingDriver { + /// Create a new Gemini embedding driver. + pub fn new(config: EmbeddingConfig) -> Result<Self, EmbeddingError> { + let resource_model = normalize_gemini_model_name(&config.model); + let dims = infer_dimensions(resource_model.trim_start_matches("models/")); + + Ok(Self { + api_key: Zeroizing::new(config.api_key), + base_url: config.base_url, + resource_model, + client: reqwest::Client::new(), + dims, + mode: std::sync::atomic::AtomicU8::new(GeminiEmbeddingMode::Batch as u8), + }) + } +} + /// Infer embedding dimensions from model name. fn infer_dimensions(model: &str) -> usize { match model { @@ -109,6 +199,7 @@ fn infer_dimensions(model: &str) -> usize { "text-embedding-3-small" => 1536, "text-embedding-3-large" => 3072, "text-embedding-ada-002" => 1536, + "gemini-embedding-2-preview" => 3072, // Sentence Transformers / local models "all-MiniLM-L6-v2" => 384, "all-MiniLM-L12-v2" => 384, @@ -120,6 +211,156 @@ fn infer_dimensions(model: &str) -> usize { } } +#[async_trait] +impl EmbeddingDriver for GeminiEmbeddingDriver { + async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> { + if texts.is_empty() { + return Ok(vec![]); + } + + match self.mode() { + GeminiEmbeddingMode::Single => self.embed_sequential(texts).await, + GeminiEmbeddingMode::Batch => match self.embed_batch(texts).await { + Ok(embeddings) => Ok(embeddings), + Err(err) if self.should_fallback_to_single(&err) => { + debug!( + error = %err, + "Gemini batch embeddings unavailable; falling back to single embedContent requests" + ); + let embeddings = self.embed_sequential(texts).await?; + self.set_mode(GeminiEmbeddingMode::Single); + Ok(embeddings) + } + Err(err) => Err(err), + }, + } + } + + fn dimensions(&self) -> usize { 
self.dims + } +} + +impl GeminiEmbeddingDriver { + fn mode(&self) -> GeminiEmbeddingMode { + match self.mode.load(std::sync::atomic::Ordering::Relaxed) { + x if x == GeminiEmbeddingMode::Single as u8 => GeminiEmbeddingMode::Single, + _ => GeminiEmbeddingMode::Batch, + } + } + + fn set_mode(&self, mode: GeminiEmbeddingMode) { + self.mode + .store(mode as u8, std::sync::atomic::Ordering::Relaxed); + } + + fn should_fallback_to_single(&self, err: &EmbeddingError) -> bool { + match err { + EmbeddingError::Api { status, message } => { + let lower = message.to_ascii_lowercase(); + *status == 403 + || *status == 404 + || (*status == 400 + && (lower.contains("batch") + || lower.contains("not supported") + || lower.contains("permission_denied") + || lower.contains("permission denied") + || lower.contains("method not found") + || lower.contains("unimplemented"))) + } + _ => false, + } + } + + async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> { + let url = format!( + "{}/v1beta/models/{}:batchEmbedContents", + self.base_url.trim_end_matches('/'), + self.resource_model.trim_start_matches("models/") + ); + let body = GeminiEmbedRequest { + requests: texts + .iter() + .map(|text| GeminiEmbedItem { + model: &self.resource_model, + content: GeminiEmbedContent { + parts: vec![GeminiEmbedPart { text }], + }, + output_dimensionality: self.dims, + }) + .collect(), + }; + + let resp = self + .client + .post(&url) + .header("x-goog-api-key", self.api_key.as_str()) + .json(&body) + .send() + .await + .map_err(|e| EmbeddingError::Http(e.to_string()))?; + + Self::parse_gemini_response(resp) + .await + .map(|data: GeminiEmbedResponse| { + data.embeddings.into_iter().map(|d| d.values).collect() + }) + } + + async fn embed_sequential(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> { + let mut embeddings = Vec::with_capacity(texts.len()); + for text in texts { + embeddings.push(self.embed_single(text).await?); + } + Ok(embeddings) + } + + async fn embed_single(&self, 
text: &str) -> Result<Vec<f32>, EmbeddingError> { + let url = format!( + "{}/v1beta/{}:embedContent", + self.base_url.trim_end_matches('/'), + self.resource_model + ); + let body = GeminiSingleEmbedRequest { + model: &self.resource_model, + content: GeminiEmbedContent { + parts: vec![GeminiEmbedPart { text }], + }, + output_dimensionality: self.dims, + }; + + let resp = self + .client + .post(&url) + .header("x-goog-api-key", self.api_key.as_str()) + .json(&body) + .send() + .await + .map_err(|e| EmbeddingError::Http(e.to_string()))?; + + Self::parse_gemini_response(resp) + .await + .map(|data: GeminiSingleEmbedResponse| data.embedding.values) + } + + async fn parse_gemini_response<T: DeserializeOwned>( + resp: reqwest::Response, + ) -> Result<T, EmbeddingError> { + let status = resp.status().as_u16(); + if status != 200 { + let body_text = resp.text().await.unwrap_or_default(); + return Err(EmbeddingError::Api { + status, + message: body_text, + }); + } + + resp.json() + .await + .map_err(|e| EmbeddingError::Parse(e.to_string())) + } +} + #[async_trait] impl EmbeddingDriver for OpenAIEmbeddingDriver { async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> { @@ -213,6 +454,7 @@ pub fn create_embedding_driver( }) .unwrap_or_else(|| match provider { "openai" => OPENAI_BASE_URL.to_string(), + "gemini" | "google" => GEMINI_BASE_URL.to_string(), "groq" => GROQ_BASE_URL.to_string(), "together" => TOGETHER_BASE_URL.to_string(), "fireworks" => FIREWORKS_BASE_URL.to_string(), @@ -245,8 +487,13 @@ pub fn create_embedding_driver( base_url, }; - let driver = OpenAIEmbeddingDriver::new(config)?; - Ok(Box::new(driver)) + if provider == "gemini" || provider == "google" { + let driver = GeminiEmbeddingDriver::new(config)?; + Ok(Box::new(driver)) + } else { + let driver = OpenAIEmbeddingDriver::new(config)?; + Ok(Box::new(driver)) + } } /// Compute cosine similarity between two vectors. 
@@ -368,6 +615,7 @@ mod tests { #[test] fn test_infer_dimensions() { assert_eq!(infer_dimensions("text-embedding-3-small"), 1536); + assert_eq!(infer_dimensions("gemini-embedding-2-preview"), 3072); assert_eq!(infer_dimensions("all-MiniLM-L6-v2"), 384); assert_eq!(infer_dimensions("nomic-embed-text"), 768); assert_eq!(infer_dimensions("unknown-model"), 1536); // default diff --git a/crates/openfang-runtime/src/mcp.rs b/crates/openfang-runtime/src/mcp.rs index b9f5f3819f..3cb063d257 100644 --- a/crates/openfang-runtime/src/mcp.rs +++ b/crates/openfang-runtime/src/mcp.rs @@ -14,7 +14,6 @@ use rmcp::service::RunningService; use rmcp::{RoleClient, ServiceExt}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::sync::Arc; use tracing::{debug, info}; // --------------------------------------------------------------------------- @@ -307,11 +306,8 @@ impl McpConnection { } } - let config = StreamableHttpClientTransportConfig { - uri: Arc::from(url), - custom_headers, - ..Default::default() - }; + let config = + StreamableHttpClientTransportConfig::with_uri(url).custom_headers(custom_headers); let transport = StreamableHttpClientTransport::from_config(config); diff --git a/crates/openfang-runtime/src/web_fetch.rs b/crates/openfang-runtime/src/web_fetch.rs index 81021aefca..7d318a7d8a 100644 --- a/crates/openfang-runtime/src/web_fetch.rs +++ b/crates/openfang-runtime/src/web_fetch.rs @@ -506,7 +506,11 @@ mod tests { assert!(check_ssrf("http://169.254.169.254/latest/meta-data/", &allow).is_err()); // Also verify hostname-based metadata blocks let allow2 = vec!["metadata.google.internal".to_string()]; - assert!(check_ssrf("http://metadata.google.internal/computeMetadata/v1/", &allow2).is_err()); + assert!(check_ssrf( + "http://metadata.google.internal/computeMetadata/v1/", + &allow2 + ) + .is_err()); } #[test] @@ -514,7 +518,7 @@ mod tests { let allow = vec!["*.example.com".to_string()]; assert!(check_ssrf("http://api.example.com", &allow).is_ok()); 
// Non-matching domain should still go through normal checks - assert!(is_host_allowed("other.net", &allow) == false); + assert!(!is_host_allowed("other.net", &allow)); } #[test] diff --git a/crates/openfang-runtime/src/web_search.rs b/crates/openfang-runtime/src/web_search.rs index 28e92259e3..11b2f5823f 100644 --- a/crates/openfang-runtime/src/web_search.rs +++ b/crates/openfang-runtime/src/web_search.rs @@ -358,7 +358,10 @@ impl WebSearchEngine { let resp = self .client - .get(format!("{}/search", self.config.searxng.url.trim_end_matches('/'))) + .get(format!( + "{}/search", + self.config.searxng.url.trim_end_matches('/') + )) .query(&[ ("q", query), ("format", "json"), @@ -451,7 +454,10 @@ impl WebSearchEngine { let resp = self .client - .get(format!("{}/config", self.config.searxng.url.trim_end_matches('/'))) + .get(format!( + "{}/config", + self.config.searxng.url.trim_end_matches('/') + )) .header("User-Agent", "Mozilla/5.0 (compatible; OpenFangAgent/0.1)") .send() .await