16 changes: 9 additions & 7 deletions crates/openfang-kernel/src/kernel.rs
@@ -881,12 +881,14 @@ impl OpenFangKernel {
// Auto-detect embedding provider by checking API key env vars in
// priority order. First match wins.
const API_KEY_PROVIDERS: &[(&str, &str)] = &[
("OPENAI_API_KEY", "openai"),
("GROQ_API_KEY", "groq"),
("MISTRAL_API_KEY", "mistral"),
("TOGETHER_API_KEY", "together"),
("GEMINI_API_KEY", "gemini"),
("GOOGLE_API_KEY", "gemini"),
("OPENAI_API_KEY", "openai"),
("GROQ_API_KEY", "groq"),
("MISTRAL_API_KEY", "mistral"),
("TOGETHER_API_KEY", "together"),
("FIREWORKS_API_KEY", "fireworks"),
("COHERE_API_KEY", "cohere"),
("COHERE_API_KEY", "cohere"),
];

let detected_from_key = API_KEY_PROVIDERS
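// Hedged sketch, not the PR's code: the diff truncates this expression.
// A plausible completion of the first-match scan described in the comment
// above (the find/map shape is an assumption):
//
//     let detected_from_key: Option<&str> = API_KEY_PROVIDERS
//         .iter()
//         .find(|(env_var, _)| std::env::var(env_var).is_ok())
//         .map(|(_, provider)| *provider);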
@@ -1127,8 +1129,7 @@ impl OpenFangKernel {
!= entry.manifest.tool_allowlist
|| disk_manifest.tool_blocklist
!= entry.manifest.tool_blocklist
-|| disk_manifest.skills
-!= entry.manifest.skills
+|| disk_manifest.skills != entry.manifest.skills
|| disk_manifest.mcp_servers
!= entry.manifest.mcp_servers;
if changed {
@@ -5700,6 +5701,7 @@ fn apply_budget_defaults(
fn default_embedding_model_for_provider(provider: &str) -> &'static str {
match provider {
"openai" => "text-embedding-3-small",
"gemini" | "google" => "gemini-embedding-2-preview",
"groq" => "nomic-embed-text",
"mistral" => "mistral-embed",
"together" => "togethercomputer/m2-bert-80M-8k-retrieval",
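// Illustrative test, not part of this PR: both Gemini aliases resolve to
// the Gemini-native default added above, and the OpenAI default is
// untouched.
#[test]
fn default_embedding_model_sketch() {
    assert_eq!(
        default_embedding_model_for_provider("gemini"),
        "gemini-embedding-2-preview"
    );
    assert_eq!(
        default_embedding_model_for_provider("google"),
        "gemini-embedding-2-preview"
    );
    assert_eq!(
        default_embedding_model_for_provider("openai"),
        "text-embedding-3-small"
    );
}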
258 changes: 253 additions & 5 deletions crates/openfang-runtime/src/embedding.rs
@@ -6,10 +6,10 @@

use async_trait::async_trait;
use openfang_types::model_catalog::{
-FIREWORKS_BASE_URL, GROQ_BASE_URL, LMSTUDIO_BASE_URL, MISTRAL_BASE_URL, OLLAMA_BASE_URL,
-OPENAI_BASE_URL, TOGETHER_BASE_URL, VLLM_BASE_URL,
+FIREWORKS_BASE_URL, GEMINI_BASE_URL, GROQ_BASE_URL, LMSTUDIO_BASE_URL, MISTRAL_BASE_URL,
+OLLAMA_BASE_URL, OPENAI_BASE_URL, TOGETHER_BASE_URL, VLLM_BASE_URL,
};
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tracing::{debug, warn};
use zeroize::Zeroizing;

@@ -70,6 +70,15 @@ pub struct OpenAIEmbeddingDriver {
dims: usize,
}

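/// Gemini-native embedding driver. Prefers the batchEmbedContents endpoint
/// and latches to per-text embedContent requests when the backend rejects
/// batching (see should_fallback_to_single below).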
pub struct GeminiEmbeddingDriver {
api_key: Zeroizing<String>,
base_url: String,
resource_model: String,
client: reqwest::Client,
dims: usize,
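/// Sticky request mode (a GeminiEmbeddingMode discriminant stored as u8)
/// so the batch-to-single fallback can latch lock-free behind &self.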
mode: std::sync::atomic::AtomicU8,
}

#[derive(Serialize)]
struct EmbedRequest<'a> {
model: &'a str,
@@ -86,6 +95,62 @@ struct EmbedData {
embedding: Vec<f32>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiEmbedRequest<'a> {
requests: Vec<GeminiEmbedItem<'a>>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiEmbedItem<'a> {
model: &'a str,
content: GeminiEmbedContent<'a>,
output_dimensionality: usize,
}

#[derive(Serialize)]
struct GeminiEmbedContent<'a> {
parts: Vec<GeminiEmbedPart<'a>>,
}

#[derive(Serialize)]
struct GeminiEmbedPart<'a> {
text: &'a str,
}

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiEmbedResponse {
embeddings: Vec<GeminiEmbedValue>,
}

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiEmbedValue {
values: Vec<f32>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiSingleEmbedRequest<'a> {
model: &'a str,
content: GeminiEmbedContent<'a>,
output_dimensionality: usize,
}

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiSingleEmbedResponse {
embedding: GeminiEmbedValue,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GeminiEmbeddingMode {
Batch = 0,
Single = 1,
}
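// Illustrative test, not part of this PR (assumes serde_json is available
// as a dependency): the camelCase rename emits the REST field name
// "outputDimensionality" for output_dimensionality.
#[test]
fn gemini_request_wire_shape_sketch() {
    let body = GeminiEmbedRequest {
        requests: vec![GeminiEmbedItem {
            model: "models/gemini-embedding-2-preview",
            content: GeminiEmbedContent {
                parts: vec![GeminiEmbedPart { text: "hello world" }],
            },
            output_dimensionality: 3072,
        }],
    };
    let json = serde_json::to_string(&body).unwrap();
    assert!(json.contains("\"outputDimensionality\":3072"));
    assert!(json.contains("\"parts\":[{\"text\":\"hello world\"}]"));
}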

impl OpenAIEmbeddingDriver {
/// Create a new OpenAI-compatible embedding driver.
pub fn new(config: EmbeddingConfig) -> Result<Self, EmbeddingError> {
@@ -102,13 +167,39 @@ impl OpenAIEmbeddingDriver {
}
}

fn normalize_gemini_model_name(model: &str) -> String {
if model.starts_with("models/") {
model.to_string()
} else {
format!("models/{model}")
}
}
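// Illustrative test, not part of this PR: the normalizer is idempotent and
// only prepends the "models/" resource prefix when it is missing.
#[test]
fn normalize_gemini_model_name_sketch() {
    assert_eq!(
        normalize_gemini_model_name("gemini-embedding-2-preview"),
        "models/gemini-embedding-2-preview"
    );
    assert_eq!(
        normalize_gemini_model_name("models/gemini-embedding-2-preview"),
        "models/gemini-embedding-2-preview"
    );
}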

impl GeminiEmbeddingDriver {
/// Create a new Gemini embedding driver.
pub fn new(config: EmbeddingConfig) -> Result<Self, EmbeddingError> {
let resource_model = normalize_gemini_model_name(&config.model);
let dims = infer_dimensions(resource_model.trim_start_matches("models/"));

Ok(Self {
api_key: Zeroizing::new(config.api_key),
base_url: config.base_url,
resource_model,
client: reqwest::Client::new(),
dims,
mode: std::sync::atomic::AtomicU8::new(GeminiEmbeddingMode::Batch as u8),
})
}
}
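// Illustrative test, not part of this PR: `new` strips the resource prefix
// before inferring dimensions, so either spelling of the model name yields
// the same 3072-dim configuration.
#[test]
fn gemini_dims_inference_sketch() {
    let dims = infer_dimensions(
        normalize_gemini_model_name("models/gemini-embedding-2-preview")
            .trim_start_matches("models/"),
    );
    assert_eq!(dims, 3072);
}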

/// Infer embedding dimensions from model name.
fn infer_dimensions(model: &str) -> usize {
match model {
// OpenAI
"text-embedding-3-small" => 1536,
"text-embedding-3-large" => 3072,
"text-embedding-ada-002" => 1536,
"gemini-embedding-2-preview" => 3072,
// Sentence Transformers / local models
"all-MiniLM-L6-v2" => 384,
"all-MiniLM-L12-v2" => 384,
Expand All @@ -120,6 +211,156 @@ fn infer_dimensions(model: &str) -> usize {
}
}

#[async_trait]
impl EmbeddingDriver for GeminiEmbeddingDriver {
async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
if texts.is_empty() {
return Ok(vec![]);
}

match self.mode() {
GeminiEmbeddingMode::Single => self.embed_sequential(texts).await,
GeminiEmbeddingMode::Batch => match self.embed_batch(texts).await {
Ok(embeddings) => Ok(embeddings),
Err(err) if self.should_fallback_to_single(&err) => {
debug!(
error = %err,
"Gemini batch embeddings unavailable; falling back to single embedContent requests"
);
let embeddings = self.embed_sequential(texts).await?;
self.set_mode(GeminiEmbeddingMode::Single);
Ok(embeddings)
}
Err(err) => Err(err),
},
}
}

fn dimensions(&self) -> usize {
self.dims
}
}
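// Illustrative usage, not part of this PR (driver construction elided; a
// configured GeminiEmbeddingDriver is assumed). The first embed() call
// probes batchEmbedContents; on a 403/404, or a 400 whose body mentions
// batching being unsupported, the sticky mode latches to Single and later
// calls go straight to per-text embedContent requests.
async fn embed_two_sketch(driver: &GeminiEmbeddingDriver) -> Result<(), EmbeddingError> {
    let vectors = driver.embed(&["first text", "second text"]).await?;
    assert_eq!(vectors.len(), 2);
    assert_eq!(vectors[0].len(), driver.dimensions());
    Ok(())
}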

impl GeminiEmbeddingDriver {
fn mode(&self) -> GeminiEmbeddingMode {
match self.mode.load(std::sync::atomic::Ordering::Relaxed) {
x if x == GeminiEmbeddingMode::Single as u8 => GeminiEmbeddingMode::Single,
_ => GeminiEmbeddingMode::Batch,
}
}

fn set_mode(&self, mode: GeminiEmbeddingMode) {
self.mode
.store(mode as u8, std::sync::atomic::Ordering::Relaxed);
}

fn should_fallback_to_single(&self, err: &EmbeddingError) -> bool {
match err {
EmbeddingError::Api { status, message } => {
let lower = message.to_ascii_lowercase();
*status == 403
|| *status == 404
|| (*status == 400
&& (lower.contains("batch")
|| lower.contains("not supported")
|| lower.contains("permission_denied")
|| lower.contains("permission denied")
|| lower.contains("method not found")
|| lower.contains("unimplemented")))
}
_ => false,
}
}

async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
let url = format!(
"{}/v1beta/models/{}:batchEmbedContents",
self.base_url.trim_end_matches('/'),
self.resource_model.trim_start_matches("models/")
);
let body = GeminiEmbedRequest {
requests: texts
.iter()
.map(|text| GeminiEmbedItem {
model: &self.resource_model,
content: GeminiEmbedContent {
parts: vec![GeminiEmbedPart { text }],
},
output_dimensionality: self.dims,
})
.collect(),
};

let resp = self
.client
.post(&url)
.header("x-goog-api-key", self.api_key.as_str())
.json(&body)
.send()
.await
.map_err(|e| EmbeddingError::Http(e.to_string()))?;

Self::parse_gemini_response(resp)
.await
.map(|data: GeminiEmbedResponse| {
data.embeddings.into_iter().map(|d| d.values).collect()
})
}
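// Illustrative test, not part of this PR: the endpoint shapes built by
// embed_batch above and embed_single below. The base value here is an
// assumption; the real one is GEMINI_BASE_URL from
// openfang_types::model_catalog.
#[test]
fn gemini_endpoint_shape_sketch() {
    let base = "https://generativelanguage.googleapis.com/";
    let model = "models/gemini-embedding-2-preview";
    let batch = format!(
        "{}/v1beta/models/{}:batchEmbedContents",
        base.trim_end_matches('/'),
        model.trim_start_matches("models/")
    );
    let single = format!(
        "{}/v1beta/{}:embedContent",
        base.trim_end_matches('/'),
        model
    );
    assert!(batch.ends_with("models/gemini-embedding-2-preview:batchEmbedContents"));
    assert!(single.ends_with("models/gemini-embedding-2-preview:embedContent"));
}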

async fn embed_sequential(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
let mut embeddings = Vec::with_capacity(texts.len());
for text in texts {
embeddings.push(self.embed_single(text).await?);
}
Ok(embeddings)
}

async fn embed_single(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
let url = format!(
"{}/v1beta/{}:embedContent",
self.base_url.trim_end_matches('/'),
self.resource_model
);
let body = GeminiSingleEmbedRequest {
model: &self.resource_model,
content: GeminiEmbedContent {
parts: vec![GeminiEmbedPart { text }],
},
output_dimensionality: self.dims,
};

let resp = self
.client
.post(&url)
.header("x-goog-api-key", self.api_key.as_str())
.json(&body)
.send()
.await
.map_err(|e| EmbeddingError::Http(e.to_string()))?;

Self::parse_gemini_response(resp)
.await
.map(|data: GeminiSingleEmbedResponse| data.embedding.values)
}

async fn parse_gemini_response<T: DeserializeOwned>(
resp: reqwest::Response,
) -> Result<T, EmbeddingError> {
let status = resp.status().as_u16();
if status != 200 {
let body_text = resp.text().await.unwrap_or_default();
return Err(EmbeddingError::Api {
status,
message: body_text,
});
}

resp.json()
.await
.map_err(|e| EmbeddingError::Parse(e.to_string()))
}
}

#[async_trait]
impl EmbeddingDriver for OpenAIEmbeddingDriver {
async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
@@ -213,6 +454,7 @@ pub fn create_embedding_driver(
})
.unwrap_or_else(|| match provider {
"openai" => OPENAI_BASE_URL.to_string(),
"gemini" | "google" => GEMINI_BASE_URL.to_string(),
"groq" => GROQ_BASE_URL.to_string(),
"together" => TOGETHER_BASE_URL.to_string(),
"fireworks" => FIREWORKS_BASE_URL.to_string(),
@@ -245,8 +487,13 @@ pub fn create_embedding_driver(
base_url,
};

-let driver = OpenAIEmbeddingDriver::new(config)?;
-Ok(Box::new(driver))
+if provider == "gemini" || provider == "google" {
+    let driver = GeminiEmbeddingDriver::new(config)?;
+    Ok(Box::new(driver))
+} else {
+    let driver = OpenAIEmbeddingDriver::new(config)?;
+    Ok(Box::new(driver))
+}
}

/// Compute cosine similarity between two vectors.
@@ -368,6 +615,7 @@ mod tests {
#[test]
fn test_infer_dimensions() {
assert_eq!(infer_dimensions("text-embedding-3-small"), 1536);
assert_eq!(infer_dimensions("gemini-embedding-2-preview"), 3072);
assert_eq!(infer_dimensions("all-MiniLM-L6-v2"), 384);
assert_eq!(infer_dimensions("nomic-embed-text"), 768);
assert_eq!(infer_dimensions("unknown-model"), 1536); // default
8 changes: 2 additions & 6 deletions crates/openfang-runtime/src/mcp.rs
@@ -14,7 +14,6 @@ use rmcp::service::RunningService;
use rmcp::{RoleClient, ServiceExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
-use std::sync::Arc;
use tracing::{debug, info};

// ---------------------------------------------------------------------------
@@ -307,11 +306,8 @@ impl McpConnection {
}
}

-let config = StreamableHttpClientTransportConfig {
-    uri: Arc::from(url),
-    custom_headers,
-    ..Default::default()
-};
+let config =
+    StreamableHttpClientTransportConfig::with_uri(url).custom_headers(custom_headers);
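// Note: the builder form replaces the struct-literal construction that
// passed `uri: Arc::from(url)`, which is why the `use std::sync::Arc;`
// import above could be dropped.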

let transport = StreamableHttpClientTransport::from_config(config);
