diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..c3e3e29 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,105 @@ +# RsChat Architecture + +## Overview + +RsChat is a real-time chat application that provides resumable streaming conversations with LLM providers. The architecture is designed for high performance, scalability across multiple server instances, and resilient streaming that can survive network interruptions. + +## Core Architecture + +### Frontend (React/TypeScript) +- **Location**: `web/` +- **Streaming**: Server-Sent Events (SSE) +- **State Management**: React Context for streaming state (`web/src/lib/context/StreamingContext.tsx`) +- **API Integration**: Type-safe API calls (generated from OpenAPI: `web/src/lib/api/types.d.ts`) + +### Backend (Rust/Rocket) +- **Location**: `server/` +- **Framework**: Rocket with async/await support +- **Database**: PostgreSQL for persistent storage +- **Cache/Streaming**: Redis for stream management and caching + +## LLM Streaming Architecture + +### Dual-Stream Approach + +RsChat uses a hybrid streaming architecture that provides both real-time performance and cross-instance resumability: + +1. **Server**: Redis Streams for resumability and multi-instance support +2. **Client**: Server-Sent Events (SSE) read from the Redis streams + +### Key Components + +#### 1. LlmStreamWriter (`server/src/stream/llm_writer.rs`) + +The core component that processes LLM provider streams and manages Redis stream output. + +**Key Features:** +- **Batching**: Accumulates chunks from the provider stream, up to a max length or timeout +- **Background Pings**: Sends regular keepalive pings +- **Database Integration**: Saves final responses to PostgreSQL + +#### 2. Redis and SSE Stream Structure + +**Redis Key for Chat Streams**: `user:{user_id}:chat:{session_id}` + +**Chat Stream Message Types**: +- `start`: Stream initialization +- `text`: Accumulated text chunks +- `tool_call`: LLM tool invocations (JSON stringified) +- `error`: Error messages +- `ping`: Keepalive messages +- `end`: Stream completion +- `cancel`: Stream cancellation + +#### 3. Stream Lifecycle + +``` +Client Request → SSE Connection → LlmStreamWriter.create() + ↓ +LLM Provider Stream → Batching Data Chunks → Redis XADD + ↓ +Background Ping Task (intervals) + ↓ +Stream End → Database Save → Redis DEL +``` + + +### Resumability Features + +#### Cross-Instance Support +- Redis streams provide shared state across server instances +- Background ping tasks maintain stream liveness +- Stream cancellation detected via Redis XADD failures + +## Data Flow + +### 1. New Chat Request +``` +Client → POST /api/chat/{session_id} + → Send request to LLM Provider + → SSE Response Stream created + → LlmStreamWriter.create() + → Redis Stream created +``` + +### 2. Stream Processing +``` +LLM Chunk → Process text, tool calls, usage, and error chunks + → Batching Logic + → Redis XADD (if conditions met) + → Continue SSE Stream +``` + +### 3. Stream Completion +``` +LLM End → Final Database Save + → Redis Stream End Event + → Redis Stream Cleanup + → SSE Connection Close +``` + +### 4. Reconnection/Resume +``` +Client Reconnect → Check ongoing streams via GET /api/chat/streams + → Reconnect to stream (if active) +``` diff --git a/server/Cargo.lock b/server/Cargo.lock index 6af3994..860b6fd 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -373,16 +373,17 @@ version = "0.6.0" dependencies = [ "aes-gcm", "astral-tokio-tar", - "async-stream", "bollard", "chrono", "const_format", + "deadpool", "diesel", "diesel-async", "diesel-derive-enum", "diesel_as_jsonb", "diesel_async_migrations", "dotenvy", + "dyn-clone", "enum-iterator", "fred", "hex", @@ -397,7 +398,6 @@ dependencies = [ "serde", "serde_json", "subst", - "tempfile", "thiserror", "tokio", "tokio-stream", @@ -623,6 +623,9 @@ name = "deadpool-runtime" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +dependencies = [ + "tokio", +] [[package]] name = "deranged" diff --git a/server/Cargo.toml b/server/Cargo.toml index 1ae19b4..3185807 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -18,10 +18,10 @@ strip = true [dependencies] aes-gcm = "0.10.3" astral-tokio-tar = "0.5.2" -async-stream = "0.3.6" bollard = { version = "0.19.1", features = ["ssl"] } chrono = { version = "0.4.41", features = ["serde"] } const_format = "0.2.34" +deadpool = { version = "0.12.2", features = ["rt_tokio_1"] } diesel = { version = "2.2.10", features = [ "postgres", "chrono", @@ -33,8 +33,12 @@ diesel-derive-enum = { version = "3.0.0-beta.1", features = ["postgres"] } diesel_as_jsonb = "1.0.1" diesel_async_migrations = "0.15.0" dotenvy = "0.15.7" +dyn-clone = "1.0.19" enum-iterator = "2.1.0" -fred = { version = "10.1.0", default-features = false, features = ["i-keys"] } +fred = { version = "10.1.0", default-features = false, features = [ + "i-keys", + "i-streams", +] } hex = "0.4.3" jsonschema = { version = "0.30.0", default-features = false } rand = "0.9.1" @@ -54,7 +58,6 @@ schemars = { version = "0.8.22", features = ["chrono", "uuid1"] } serde = { version = "1.0.219" } serde_json = "1.0.140" subst = { version = "0.3.8", features = ["json"] } -tempfile = "3.20.0" thiserror = "2.0.12" tokio = { version = "1.45.1" } tokio-stream = "0.1.17" diff --git a/server/src/api.rs b/server/src/api.rs index 45183b1..8033fe7 100644 --- a/server/src/api.rs +++ b/server/src/api.rs @@ -1,6 +1,7 @@ mod api_key; mod auth; mod chat; +mod info; mod provider; mod secret; mod session; @@ -9,6 +10,7 @@ mod tool; pub use api_key::get_routes as api_key_routes; pub use auth::get_routes as auth_routes; pub use chat::get_routes as chat_routes; +pub use info::get_routes as info_routes; pub use provider::get_routes as provider_routes; pub use secret::get_routes as secret_routes; pub use session::get_routes as session_routes; diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index 46a24ec..fbc44c5 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -1,8 +1,8 @@ use std::{borrow::Cow, pin::Pin}; use rocket::{ - futures::{Stream, StreamExt}, - post, + futures::{stream, Stream, StreamExt}, + get, post, response::stream::{Event, EventStream}, serde::json::Json, Route, State, @@ -11,6 +11,7 @@ use rocket_okapi::{ okapi::openapi3::OpenApi, openapi, openapi_get_routes_spec, settings::OpenApiSettings, }; use schemars::JsonSchema; +use tokio_stream::wrappers::ReceiverStream; use uuid::Uuid; use crate::{ @@ -18,21 +19,47 @@ use crate::{ auth::ChatRsUserId, db::{ models::{ - ChatRsMessageMeta, ChatRsMessageRole, ChatRsProviderType, ChatRsSessionMeta, + AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsSessionMeta, NewChatRsMessage, UpdateChatRsSession, }, services::{ChatDbService, ProviderDbService, ToolDbService}, DbConnection, DbPool, }, errors::ApiError, - provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmTool}, - redis::RedisClient, - tools::SendChatToolInput, - utils::{generate_title, Encryptor, StoredChatRsStream}, + provider::{build_llm_provider_api, LlmError, LlmProviderOptions}, + redis::{ExclusiveRedisClient, RedisClient}, + stream::{ + cancel_current_chat_stream, check_chat_stream_exists, get_current_chat_streams, + LastEventId, LlmStreamWriter, SseStreamReader, + }, + tools::{get_llm_tools_from_input, SendChatToolInput}, + utils::{generate_title, Encryptor}, }; pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { - openapi_get_routes_spec![settings: send_chat_stream] + openapi_get_routes_spec![ + settings: get_chat_streams, + send_chat_stream, + connect_to_chat_stream, + cancel_chat_stream, + ] +} + +#[derive(Debug, JsonSchema, serde::Serialize)] +pub struct GetChatStreamsResponse { + sessions: Vec, +} + +/// # Get chat streams +/// Get the ongoing chat response streams +#[openapi(tag = "Chat")] +#[get("/streams")] +pub async fn get_chat_streams( + user_id: ChatRsUserId, + redis: RedisClient, +) -> Result, ApiError> { + let sessions = get_current_chat_streams(&redis, &user_id).await?; + Ok(Json(GetChatStreamsResponse { sessions })) } #[derive(JsonSchema, serde::Deserialize)] @@ -41,13 +68,21 @@ pub struct SendChatInput<'a> { message: Option>, /// The ID of the provider to chat with provider_id: i32, - /// Provider options - provider_options: LlmApiProviderSharedOptions, + /// Configuration for the provider + options: LlmProviderOptions, /// Configuration of tools available to the assistant tools: Option, } -/// Send a chat message and stream the response +#[derive(JsonSchema, serde::Serialize)] +pub struct SendChatResponse { + message: &'static str, + url: String, +} + +/// # Start chat stream +/// Send a chat message and start the streamed assistant response. After the response +/// has started, use the `//stream` endpoint to connect to the SSE stream. #[openapi(tag = "Chat")] #[post("/", data = "")] pub async fn send_chat_stream( @@ -55,26 +90,31 @@ pub async fn send_chat_stream( db_pool: &State, mut db: DbConnection, redis: RedisClient, + redis_writer: ExclusiveRedisClient, encryptor: &State, http_client: &State, session_id: Uuid, mut input: Json>, -) -> Result + Send>>>, ApiError> { - // Check session exists and user is owner, get message history - let (session, mut current_messages) = ChatDbService::new(&mut db) +) -> Result, ApiError> { + // Check that we aren't already streaming a response for this session + if check_chat_stream_exists(&redis, &user_id, &session_id).await? { + return Err(LlmError::AlreadyStreaming)?; + } + + // Get session and message history + let (session, mut messages) = ChatDbService::new(&mut db) .get_session_with_messages(&user_id, &session_id) .await?; - // Build the chat provider + // Build the LLM provider let (provider, api_key_secret) = ProviderDbService::new(&mut db) .get_by_id(&user_id, input.provider_id) .await?; - let provider_type: ChatRsProviderType = provider.provider_type.as_str().try_into()?; let api_key = api_key_secret .map(|secret| encryptor.decrypt_string(&secret.ciphertext, &secret.nonce)) .transpose()?; let provider_api = build_llm_provider_api( - &provider_type, + &provider.provider_type.as_str().try_into()?, provider.base_url.as_deref(), api_key.as_deref(), &http_client, @@ -82,40 +122,24 @@ pub async fn send_chat_stream( )?; // Get the user's chosen tools - let mut llm_tools: Option> = None; - let mut tool_db_service = ToolDbService::new(&mut db); - if let Some(system_tool_input) = input.tools.as_ref().and_then(|t| t.system.as_ref()) { - let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; - let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; - llm_tools.get_or_insert_default().extend(system_llm_tools); - } - if let Some(external_apis_input) = input.tools.as_ref().and_then(|t| t.external_apis.as_ref()) { - let external_api_tools = tool_db_service - .find_external_api_tools_by_user(&user_id) - .await?; - for tool_input in external_apis_input { - let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; - llm_tools.get_or_insert_default().extend(api_llm_tools); - } + let mut tools = None; + if let Some(tool_input) = input.tools.as_ref() { + let mut tool_db_service = ToolDbService::new(&mut db); + tools = Some(get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?); } - // Save user message and generate session title if needed + // Generate session title if needed, and save user message to database if let Some(user_message) = &input.message { - if current_messages.is_empty() && session.title == DEFAULT_SESSION_TITLE { + if messages.is_empty() && session.title == DEFAULT_SESSION_TITLE { generate_title( &user_id, &session_id, &user_message, - provider_type, + &provider_api, &provider.default_model, - provider.base_url.as_deref(), - api_key.clone(), - &http_client, - &redis, db_pool, ); } - let new_message = ChatDbService::new(&mut db) .save_message(NewChatRsMessage { content: user_message, @@ -124,49 +148,118 @@ pub async fn send_chat_stream( meta: ChatRsMessageMeta::default(), }) .await?; - current_messages.push(new_message); + messages.push(new_message); } - // Update session metadata + // Update session metadata if needed if let Some(tool_input) = input.tools.take() { if session .meta .tool_config .is_none_or(|config| config != tool_input) { + let meta = ChatRsSessionMeta::new(Some(tool_input)); + let data = UpdateChatRsSession { + meta: Some(&meta), + ..Default::default() + }; ChatDbService::new(&mut db) - .update_session( - &user_id, - &session_id, - UpdateChatRsSession { - meta: Some(&ChatRsSessionMeta { - tool_config: Some(tool_input), - }), - ..Default::default() - }, - ) + .update_session(&user_id, &session_id, data) .await?; } } - // Get the provider's stream response and wrap it in our StoredChatRsStream - let stream = StoredChatRsStream::new( - provider_api - .chat_stream(current_messages, llm_tools, &input.provider_options) - .await?, - input.provider_id, - input.provider_options.clone(), - db_pool.inner().clone(), - redis.clone(), - Some(session_id), - ); - - // Start streaming - let event_stream = stream - .map(|result| match result { - Ok(message) => Event::data(format!(" {message}")).event("chat"), - Err(err) => Event::data(err.to_string()).event("error"), - }) - .boxed(); - Ok(EventStream::from(event_stream)) + // Get the provider's stream response + let stream = provider_api + .chat_stream(messages, tools, &input.options) + .await?; + let provider_id = input.provider_id; + let provider_options = input.options.clone(); + + // Create the Redis stream + let mut stream_writer = LlmStreamWriter::new(redis_writer, &user_id, &session_id); + stream_writer.start().await?; + + // Spawn a task to stream and save the response + tokio::spawn(async move { + let (text, tool_calls, usage, errors, cancelled) = stream_writer.process(stream).await; + let assistant_meta = AssistantMeta { + provider_id, + provider_options: Some(provider_options), + tool_calls, + usage, + errors, + partial: cancelled.then_some(true), + }; + let db_result = ChatDbService::new(&mut db) + .save_message(NewChatRsMessage { + session_id: &session_id, + role: ChatRsMessageRole::Assistant, + content: &text.unwrap_or_default(), + meta: ChatRsMessageMeta::new_assistant(assistant_meta), + }) + .await; + if let Err(err) = db_result { + rocket::error!("Failed to save assistant message: {}", err); + } + if !cancelled { + stream_writer.end().await.ok(); + } + }); + + Ok(Json(SendChatResponse { + message: "Stream started", + url: format!("/api/chat/{}/stream", session_id), + })) +} + +/// # Connect to chat stream +/// Connect to an ongoing chat stream and stream the assistant response +#[openapi(tag = "Chat")] +#[get("//stream")] +pub async fn connect_to_chat_stream( + user_id: ChatRsUserId, + redis_reader: ExclusiveRedisClient, + session_id: Uuid, + start_event_id: Option, +) -> Result + Send>>>, ApiError> { + let stream_reader = SseStreamReader::new(redis_reader); + + // Get all previous events from the Redis stream, and return them if we're already at the end of the stream + let (prev_events, last_event_id, is_end) = stream_reader + .get_prev_events(&user_id, &session_id, start_event_id.as_deref()) + .await?; + let prev_events_stream = stream::iter(prev_events); + if is_end { + return Ok(EventStream::from(prev_events_stream.boxed())); + } + + // Spawn a task to receive new events from Redis and add them to this channel + let (tx, rx) = tokio::sync::mpsc::channel::(50); + tokio::spawn(async move { + stream_reader + .stream(&user_id, &session_id, &last_event_id, &tx) + .await; + drop(tx); + }); + + // Send stream to client + let stream = prev_events_stream.chain(ReceiverStream::new(rx)).boxed(); + Ok(EventStream::from(stream)) +} + +/// # Cancel chat stream +/// Cancel an ongoing chat stream +#[openapi(tag = "Chat")] +#[post("//cancel")] +pub async fn cancel_chat_stream( + user_id: ChatRsUserId, + redis: RedisClient, + session_id: Uuid, +) -> Result<(), ApiError> { + if !check_chat_stream_exists(&redis, &user_id, &session_id).await? { + return Err(LlmError::StreamNotFound)?; + } + cancel_current_chat_stream(&redis, &user_id, &session_id).await?; + Ok(()) } diff --git a/server/src/api/info.rs b/server/src/api/info.rs new file mode 100644 index 0000000..e16ea32 --- /dev/null +++ b/server/src/api/info.rs @@ -0,0 +1,57 @@ +use rocket::{get, serde::json::Json, Route, State}; +use rocket_okapi::{ + okapi::openapi3::OpenApi, openapi, openapi_get_routes_spec, settings::OpenApiSettings, +}; +use schemars::JsonSchema; +use serde::Serialize; + +use crate::{auth::ChatRsUserId, config::AppConfig, errors::ApiError, redis::ExclusiveClientPool}; + +pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { + openapi_get_routes_spec![ + settings: get_info + ] +} + +#[derive(Debug, Serialize, JsonSchema)] +struct InfoResponse { + version: String, + url: String, + redis: RedisStats, +} + +#[derive(Debug, Serialize, JsonSchema)] +struct RedisStats { + /// Number of static connections + r#static: usize, + /// Number of current streaming connections + streaming: usize, + /// Number of available streaming connections + streaming_available: usize, + /// Maximum number of streaming connections + streaming_max: usize, +} + +/// # Get info +/// Get information about the server +#[openapi] +#[get("/")] +async fn get_info( + _user_id: ChatRsUserId, + app_config: &State, + redis_pool: &State, +) -> Result, ApiError> { + let redis_status = redis_pool.status(); + let redis_stats = RedisStats { + r#static: app_config.redis_pool.unwrap_or(4), + streaming: redis_status.size, + streaming_max: redis_status.max_size, + streaming_available: redis_status.available, + }; + + Ok(Json(InfoResponse { + version: format!("v{}", env!("CARGO_PKG_VERSION")), + url: app_config.server_address.clone(), + redis: redis_stats, + })) +} diff --git a/server/src/api/session.rs b/server/src/api/session.rs index a67573d..6be63d3 100644 --- a/server/src/api/session.rs +++ b/server/src/api/session.rs @@ -1,4 +1,3 @@ -use fred::prelude::KeysInterface; use rocket::{delete, get, patch, post, serde::json::Json, Route}; use rocket_okapi::{ okapi::openapi3::OpenApi, openapi, openapi_get_routes_spec, settings::OpenApiSettings, @@ -10,16 +9,12 @@ use uuid::Uuid; use crate::{ auth::ChatRsUserId, db::{ - models::{ - AssistantMeta, ChatRsMessage, ChatRsMessageMeta, ChatRsMessageRole, ChatRsSession, - NewChatRsSession, UpdateChatRsSession, - }, + models::{ChatRsMessage, ChatRsSession, NewChatRsSession, UpdateChatRsSession}, services::ChatDbService, DbConnection, }, errors::ApiError, - redis::RedisClient, - utils::{SessionSearchResult, CHAT_CACHE_KEY_PREFIX}, + utils::SessionSearchResult, }; pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { @@ -84,34 +79,12 @@ struct GetSessionResponse { async fn get_session( user_id: ChatRsUserId, mut db: DbConnection, - redis: RedisClient, session_id: Uuid, ) -> Result, ApiError> { - let (session, mut messages) = ChatDbService::new(&mut db) + let (session, messages) = ChatDbService::new(&mut db) .get_session_with_messages(&user_id, &session_id) .await?; - // Check for a cached response if the session is interrupted - let cached_response: Option = redis - .get(format!("{}{}", CHAT_CACHE_KEY_PREFIX, &session_id)) - .await?; - if let Some(interrupted_response) = cached_response { - messages.push(ChatRsMessage { - id: Uuid::new_v4(), - session_id, - role: ChatRsMessageRole::Assistant, - content: interrupted_response, - created_at: chrono::Utc::now(), - meta: ChatRsMessageMeta { - assistant: Some(AssistantMeta { - partial: Some(true), - ..Default::default() - }), - ..Default::default() - }, - }); - } - Ok(Json(GetSessionResponse { session, messages })) } diff --git a/server/src/api/tool.rs b/server/src/api/tool.rs index b674d63..b9f9240 100644 --- a/server/src/api/tool.rs +++ b/server/src/api/tool.rs @@ -256,6 +256,7 @@ async fn execute_tool( tool_call: Some(ChatRsExecutedToolCall { id: tool_call.id, tool_id: tool_call.tool_id, + tool_name: tool_call.tool_name, tool_type: tool_call.tool_type, response_format: format, is_error, diff --git a/server/src/auth/session.rs b/server/src/auth/session.rs index 699ba6a..0e9a6c9 100644 --- a/server/src/auth/session.rs +++ b/server/src/auth/session.rs @@ -1,11 +1,11 @@ -use std::{ops::Deref, time::Duration}; +use std::ops::Deref; use chrono::Utc; use rocket::fairing::AdHoc; use rocket_flex_session::{storage::redis::RedisFredStorage, RocketFlexSession}; use uuid::Uuid; -use crate::config::get_app_config; +use crate::{config::get_app_config, redis::build_redis_pool}; const USER_ID_KEY: &str = "user_id"; const USER_ID_BYTES_KEY: &str = "user_id_bytes"; @@ -74,16 +74,7 @@ pub fn setup_session() -> AdHoc { let app_config = get_app_config(&rocket); let config = fred::prelude::Config::from_url(&app_config.redis_url) .expect("RS_CHAT_REDIS_URL should be valid Redis URL"); - let session_redis_pool = fred::prelude::Builder::from_config(config) - .with_connection_config(|config| { - config.connection_timeout = Duration::from_secs(4); - config.tcp = fred::prelude::TcpConfig { - nodelay: Some(true), - ..Default::default() - }; - }) - .build_pool(app_config.redis_pool.unwrap_or(2)) - .expect("Failed to build Redis session pool"); + let session_redis_pool = build_redis_pool(config, 2).expect("Failed to build Redis pool"); let session_fairing: RocketFlexSession = RocketFlexSession::builder() .with_options(|opt| { opt.cookie_name = "auth_rs_chat".to_string(); diff --git a/server/src/config.rs b/server/src/config.rs index 18238ca..0fc5183 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -20,8 +20,10 @@ pub struct AppConfig { pub database_url: String, /// Redis connection URL pub redis_url: String, - /// Redis pool size (default: 2) + /// Redis static pool size (default: 4) pub redis_pool: Option, + /// Maximum number of concurrent Redis connections for streaming (default: 20) + pub max_streams: Option, } /// Get the server configuration variables from Rocket diff --git a/server/src/db/models/chat.rs b/server/src/db/models/chat.rs index a32ccc5..ad3db9e 100644 --- a/server/src/db/models/chat.rs +++ b/server/src/db/models/chat.rs @@ -10,7 +10,7 @@ use uuid::Uuid; use crate::{ db::models::{ChatRsExecutedToolCall, ChatRsToolCall, ChatRsUser}, - provider::{LlmApiProviderSharedOptions, LlmUsage}, + provider::{LlmProviderOptions, LlmUsage}, tools::SendChatToolInput, }; @@ -33,6 +33,11 @@ pub struct ChatRsSessionMeta { #[serde(skip_serializing_if = "Option::is_none")] pub tool_config: Option, } +impl ChatRsSessionMeta { + pub fn new(tool_config: Option) -> Self { + Self { tool_config } + } +} #[derive(Insertable)] #[diesel(table_name = super::schema::chat_sessions)] @@ -79,6 +84,14 @@ pub struct ChatRsMessageMeta { #[serde(skip_serializing_if = "Option::is_none")] pub tool_call: Option, } +impl ChatRsMessageMeta { + pub fn new_assistant(assistant: AssistantMeta) -> Self { + Self { + assistant: Some(assistant), + tool_call: None, + } + } +} #[derive(Debug, Default, JsonSchema, Serialize, Deserialize)] pub struct AssistantMeta { @@ -86,13 +99,16 @@ pub struct AssistantMeta { pub provider_id: i32, /// Options passed to the LLM provider #[serde(skip_serializing_if = "Option::is_none")] - pub provider_options: Option, + pub provider_options: Option, /// The tool calls requested by the assistant #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, /// Provider usage information #[serde(skip_serializing_if = "Option::is_none")] pub usage: Option, + /// Errors encountered during message generation + #[serde(skip_serializing_if = "Option::is_none")] + pub errors: Option>, /// Whether this is a partial and/or interrupted message #[serde(skip_serializing_if = "Option::is_none")] pub partial: Option, diff --git a/server/src/db/models/provider.rs b/server/src/db/models/provider.rs index 6ecdbb3..6ce9af7 100644 --- a/server/src/db/models/provider.rs +++ b/server/src/db/models/provider.rs @@ -51,6 +51,7 @@ pub struct UpdateChatRsProvider<'a> { pub enum ChatRsProviderType { Anthropic, Openai, + Ollama, Lorem, } @@ -61,6 +62,7 @@ impl TryFrom<&str> for ChatRsProviderType { match value { "anthropic" => Ok(ChatRsProviderType::Anthropic), "openai" => Ok(ChatRsProviderType::Openai), + "ollama" => Ok(ChatRsProviderType::Ollama), "lorem" => Ok(ChatRsProviderType::Lorem), _ => Err(LlmError::UnsupportedProvider), } @@ -72,6 +74,7 @@ impl From<&ChatRsProviderType> for &str { match value { ChatRsProviderType::Anthropic => "anthropic", ChatRsProviderType::Openai => "openai", + ChatRsProviderType::Ollama => "ollama", ChatRsProviderType::Lorem => "lorem", } } diff --git a/server/src/db/models/tool.rs b/server/src/db/models/tool.rs index 19f899e..a328ea5 100644 --- a/server/src/db/models/tool.rs +++ b/server/src/db/models/tool.rs @@ -55,7 +55,7 @@ pub struct NewChatRsExternalApiTool<'r> { } /// A tool call requested by the provider -#[derive(Debug, JsonSchema, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, JsonSchema, serde::Serialize, serde::Deserialize)] pub struct ChatRsToolCall { /// ID of the tool call pub id: String, @@ -77,6 +77,9 @@ pub struct ChatRsExecutedToolCall { pub id: String, /// ID of the tool used pub tool_id: Uuid, + /// Name of the tool used + #[serde(default)] + pub tool_name: String, /// Type of the tool used #[serde(default)] pub tool_type: LlmToolType, diff --git a/server/src/errors.rs b/server/src/errors.rs index 958210f..7adefed 100644 --- a/server/src/errors.rs +++ b/server/src/errors.rs @@ -1,3 +1,4 @@ +use diesel_async::pooled_connection::deadpool; use rocket::{ catch, catchers, response::{self, Responder}, @@ -13,6 +14,8 @@ use crate::{provider::LlmError, tools::ToolError}; pub enum ApiError { #[error(transparent)] Db(#[from] diesel::result::Error), + #[error(transparent)] + DbPool(#[from] deadpool::PoolError), #[error("Authentication error: {0}")] Authentication(String), #[error("Redis error: {0}")] diff --git a/server/src/lib.rs b/server/src/lib.rs index 473e867..828ce99 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -6,6 +6,7 @@ pub mod errors; pub mod provider; pub mod provider_models; pub mod redis; +pub mod stream; pub mod tools; pub mod utils; pub mod web; @@ -40,6 +41,7 @@ pub fn build_rocket() -> rocket::Rocket { mount_endpoints_and_merged_docs! { server, "/api", openapi_settings, "/" => openapi_get_routes_spec![health], + "/info" => api::info_routes(&openapi_settings), "/auth" => api::auth_routes(&openapi_settings), "/provider" => api::provider_routes(&openapi_settings), "/session" => api::session_routes(&openapi_settings), diff --git a/server/src/provider.rs b/server/src/provider.rs index 2befe4d..bc4fb32 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -2,17 +2,23 @@ pub mod anthropic; pub mod lorem; +pub mod ollama; pub mod openai; +mod utils; use std::pin::Pin; +use dyn_clone::DynClone; use rocket::{async_trait, futures::Stream}; use schemars::JsonSchema; use uuid::Uuid; use crate::{ db::models::{ChatRsMessage, ChatRsProviderType, ChatRsToolCall}, - provider::{anthropic::AnthropicProvider, lorem::LoremProvider, openai::OpenAIProvider}, + provider::{ + anthropic::AnthropicProvider, lorem::LoremProvider, ollama::OllamaProvider, + openai::OpenAIProvider, + }, provider_models::LlmModel, }; @@ -24,18 +30,22 @@ pub const DEFAULT_TEMPERATURE: f32 = 0.7; pub enum LlmError { #[error("Missing API key")] MissingApiKey, - #[error("Lorem ipsum error: {0}")] - LoremError(&'static str), - #[error("Anthropic error: {0}")] - AnthropicError(String), - #[error("OpenAI error: {0}")] - OpenAIError(String), + #[error("Provider error: {0}")] + ProviderError(String), #[error("models.dev error: {0}")] ModelsDevError(String), #[error("No chat response")] NoResponse, #[error("Unsupported provider")] UnsupportedProvider, + #[error("Already streaming a response for this session")] + AlreadyStreaming, + #[error("No stream found, or the stream was cancelled")] + StreamNotFound, + #[error("Missing event in stream")] + NoStreamEvent, + #[error("Client disconnected")] + ClientDisconnected, #[error("Encryption error")] EncryptionError, #[error("Decryption error")] @@ -44,16 +54,45 @@ pub enum LlmError { Redis(#[from] fred::error::Error), } +/// LLM errors during streaming +#[derive(Debug, thiserror::Error)] +pub enum LlmStreamError { + #[error("Provider error: {0}")] + ProviderError(String), + #[error("Failed to parse event: {0}")] + Parsing(#[from] serde_json::Error), + #[error("Failed to decode response: {0}")] + Decoding(#[from] tokio_util::codec::LinesCodecError), + #[error("Timeout waiting for provider response")] + StreamTimeout, + #[error("Stream was cancelled")] + StreamCancelled, + #[error("Redis error: {0}")] + Redis(#[from] fred::error::Error), +} + +/// Shared stream response type for LLM providers +pub type LlmStream = Pin + Send>>; + +/// Shared stream chunk result type for LLM providers +pub type LlmStreamChunkResult = Result; + /// A streaming chunk of data from the LLM provider -#[derive(Default)] -pub struct LlmStreamChunk { - pub text: Option, - pub tool_calls: Option>, - pub usage: Option, +pub enum LlmStreamChunk { + Text(String), + ToolCalls(Vec), + PendingToolCall(LlmPendingToolCall), + Usage(LlmUsage), +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct LlmPendingToolCall { + pub index: usize, + pub tool_name: String, } /// Usage stats from the LLM provider -#[derive(Debug, JsonSchema, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] pub struct LlmUsage { pub input_tokens: Option, pub output_tokens: Option, @@ -62,12 +101,9 @@ pub struct LlmUsage { pub cost: Option, } -/// Shared stream type for LLM providers -pub type LlmApiStream = Pin> + Send>>; - /// Shared configuration for LLM provider requests -#[derive(Clone, Debug, JsonSchema, serde::Serialize, serde::Deserialize)] -pub struct LlmApiProviderSharedOptions { +#[derive(Clone, Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] +pub struct LlmProviderOptions { pub model: String, pub temperature: Option, pub max_tokens: Option, @@ -95,34 +131,31 @@ pub enum LlmToolType { /// Unified API for LLM providers #[async_trait] -pub trait LlmApiProvider { +pub trait LlmApiProvider: Send + Sync + DynClone { /// Stream a chat response from the provider async fn chat_stream( &self, messages: Vec, tools: Option>, - options: &LlmApiProviderSharedOptions, - ) -> Result; + options: &LlmProviderOptions, + ) -> Result; /// Submit a prompt to the provider (not streamed) - async fn prompt( - &self, - message: &str, - options: &LlmApiProviderSharedOptions, - ) -> Result; + async fn prompt(&self, message: &str, options: &LlmProviderOptions) + -> Result; /// List available models from the provider async fn list_models(&self) -> Result, LlmError>; } /// Build the LLM API to make calls to the provider -pub fn build_llm_provider_api<'a>( +pub fn build_llm_provider_api( provider_type: &ChatRsProviderType, - base_url: Option<&'a str>, - api_key: Option<&'a str>, + base_url: Option<&str>, + api_key: Option<&str>, http_client: &reqwest::Client, - redis: &'a fred::prelude::Client, -) -> Result, LlmError> { + redis: &fred::clients::Client, +) -> Result, LlmError> { match provider_type { ChatRsProviderType::Openai => Ok(Box::new(OpenAIProvider::new( http_client, @@ -135,6 +168,10 @@ pub fn build_llm_provider_api<'a>( redis, api_key.ok_or(LlmError::MissingApiKey)?, ))), + ChatRsProviderType::Ollama => Ok(Box::new(OllamaProvider::new( + http_client, + base_url.unwrap_or("http://localhost:11434"), + ))), ChatRsProviderType::Lorem => Ok(Box::new(LoremProvider::new())), } } diff --git a/server/src/provider/anthropic.rs b/server/src/provider/anthropic.rs index dc4a631..bcea95a 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/anthropic.rs @@ -1,267 +1,62 @@ //! Anthropic LLM provider -use std::collections::HashMap; +mod request; +mod response; -use rocket::async_trait; -use serde::{Deserialize, Serialize}; +use rocket::{async_stream, async_trait, futures::StreamExt}; use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole, ChatRsToolCall}, + db::models::ChatRsMessage, provider::{ - LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmStreamChunk, - LlmTool, LlmUsage, DEFAULT_MAX_TOKENS, + utils::get_sse_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmTool, + LlmUsage, DEFAULT_MAX_TOKENS, }, provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, }; +use { + request::{ + build_anthropic_messages, build_anthropic_tools, AnthropicContentBlock, AnthropicMessage, + AnthropicRequest, + }, + response::{parse_anthropic_event, AnthropicResponse, AnthropicResponseContentBlock}, +}; + const MESSAGES_API_URL: &str = "https://api.anthropic.com/v1/messages"; const API_VERSION: &str = "2023-06-01"; /// Anthropic chat provider -pub struct AnthropicProvider<'a> { +#[derive(Debug, Clone)] +pub struct AnthropicProvider { client: reqwest::Client, - redis: &'a fred::prelude::Client, - api_key: &'a str, + redis: fred::clients::Client, + api_key: String, } -impl<'a> AnthropicProvider<'a> { +impl AnthropicProvider { pub fn new( http_client: &reqwest::Client, - redis: &'a fred::prelude::Client, - api_key: &'a str, + redis: &fred::clients::Client, + api_key: &str, ) -> Self { Self { client: http_client.clone(), - redis, - api_key, + redis: redis.clone(), + api_key: api_key.to_string(), } } - - fn build_messages( - &self, - messages: &'a [ChatRsMessage], - ) -> (Vec>, Option<&'a str>) { - let system_prompt = messages - .iter() - .rfind(|message| message.role == ChatRsMessageRole::System) - .map(|message| message.content.as_str()); - - let anthropic_messages: Vec = messages - .iter() - .filter_map(|message| { - let role = match message.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Tool => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => return None, - }; - - let mut content_blocks = Vec::new(); - - // Handle tool result messages - if message.role == ChatRsMessageRole::Tool { - if let Some(executed_call) = &message.meta.tool_call { - content_blocks.push(AnthropicContentBlock::ToolResult { - tool_use_id: &executed_call.id, - content: &message.content, - }); - } - } else { - // Handle regular text content - if !message.content.is_empty() { - content_blocks.push(AnthropicContentBlock::Text { - text: &message.content, - }); - } - // Handle tool calls in assistant messages - if let Some(tool_calls) = message - .meta - .assistant - .as_ref() - .and_then(|a| a.tool_calls.as_ref()) - { - for tool_call in tool_calls { - content_blocks.push(AnthropicContentBlock::ToolUse { - id: &tool_call.id, - name: &tool_call.tool_name, - input: &tool_call.parameters, - }); - } - } - } - - if content_blocks.is_empty() { - return None; - } - - Some(AnthropicMessage { - role, - content: content_blocks, - }) - }) - .collect(); - - (anthropic_messages, system_prompt) - } - - fn build_tools(&self, tools: &'a [LlmTool]) -> Vec> { - tools - .iter() - .map(|tool| AnthropicTool { - name: &tool.name, - description: &tool.description, - input_schema: &tool.input_schema, - }) - .collect() - } - - async fn parse_sse_stream( - &self, - mut response: reqwest::Response, - tools: Option>, - ) -> LlmApiStream { - let stream = async_stream::stream! { - let mut buffer = String::new(); - let mut current_tool_calls: Vec> = Vec::new(); - - while let Some(chunk) = response.chunk().await.transpose() { - match chunk { - Ok(bytes) => { - let text = String::from_utf8_lossy(&bytes); - buffer.push_str(&text); - - while let Some(line_end_idx) = buffer.find('\n') { - let line = buffer[..line_end_idx].trim_end_matches('\r'); - - if line.starts_with("data: ") { - let data = &line[6..]; // Remove "data: " prefix - if data.trim().is_empty() || data == "[DONE]" { - buffer.drain(..=line_end_idx); - continue; - } - - match serde_json::from_str::(data) { - Ok(event) => { - match event { - AnthropicStreamEvent::MessageStart { message } => { - if let Some(usage) = message.usage { - yield Ok(LlmStreamChunk { - text: Some(String::new()), - tool_calls: None, - usage: Some(usage.into()), - }); - } - } - AnthropicStreamEvent::ContentBlockStart { content_block, index } => { - match content_block { - AnthropicResponseContentBlock::Text { text } => { - yield Ok(LlmStreamChunk { - text: Some(text), - tool_calls: None, - usage: None, - }); - } - AnthropicResponseContentBlock::ToolUse { id, name } => { - current_tool_calls.push(Some(AnthropicStreamToolCall { - id, - index, - name, - input: String::with_capacity(100), - })); - } - } - } - AnthropicStreamEvent::ContentBlockDelta { delta, index } => { - match delta { - AnthropicDelta::TextDelta { text } => { - yield Ok(LlmStreamChunk { - text: Some(text), - tool_calls: None, - usage: None, - }); - } - AnthropicDelta::InputJsonDelta { partial_json } => { - if let Some(Some(tool_call)) = current_tool_calls.iter_mut().find(|tc| tc.as_ref().is_some_and(|tc| tc.index == index)) { - tool_call.input.push_str(&partial_json); - } - } - } - } - AnthropicStreamEvent::ContentBlockStop { index } => { - if let Some(llm_tools) = &tools { - if !current_tool_calls.is_empty() { - let converted_call = current_tool_calls - .iter_mut() - .find(|tc| tc.as_ref().is_some_and(|tc| tc.index == index)) - .and_then(|tc| tc.take()) - .and_then(|tc| tc.convert(llm_tools)); - if let Some(converted_call) = converted_call { - yield Ok(LlmStreamChunk { - text: None, - tool_calls: Some(vec![converted_call]), - usage: None, - }); - } - } - } - } - AnthropicStreamEvent::MessageDelta { usage } => { - if let Some(usage) = usage { - yield Ok(LlmStreamChunk { - text: Some(String::new()), - tool_calls: None, - usage: Some(usage.into()), - }); - } - } - AnthropicStreamEvent::Error { error } => { - yield Err(LlmError::AnthropicError( - format!("{}: {}", error.error_type, error.message) - )); - } - _ => {} // Ignore other events (ping, message_stop) - } - } - Err(e) => { - rocket::warn!("Failed to parse SSE event: {} | Data: {}", e, data); - } - } - } else if line.starts_with("event: ") { - let event_type = &line[7..]; - rocket::debug!("SSE event type: {}", event_type); - } else if !line.trim().is_empty() && !line.starts_with(":") { - rocket::debug!("Unexpected SSE line: {}", line); - } - - buffer.drain(..=line_end_idx); - } - } - Err(e) => { - rocket::warn!("Stream chunk error: {}", e); - yield Err(LlmError::AnthropicError(format!("Stream error: {}", e))); - break; - } - } - } - - rocket::debug!("Anthropic SSE stream ended"); - }; - - Box::pin(stream) - } } #[async_trait] -impl<'a> LlmApiProvider for AnthropicProvider<'a> { +impl LlmApiProvider for AnthropicProvider { async fn chat_stream( &self, messages: Vec, tools: Option>, - options: &LlmApiProviderSharedOptions, - ) -> Result { - let (anthropic_messages, system_prompt) = self.build_messages(&messages); - let anthropic_tools = tools.as_ref().map(|t| self.build_tools(t)); - + options: &LlmProviderOptions, + ) -> Result { + let (anthropic_messages, system_prompt) = build_anthropic_messages(&messages); + let anthropic_tools = tools.as_ref().map(|t| build_anthropic_tools(t)); let request = AnthropicRequest { model: &options.model, messages: anthropic_messages, @@ -277,28 +72,42 @@ impl<'a> LlmApiProvider for AnthropicProvider<'a> { .post(MESSAGES_API_URL) .header("anthropic-version", API_VERSION) .header("content-type", "application/json") - .header("x-api-key", self.api_key) + .header("x-api-key", &self.api_key) .json(&request) .send() .await - .map_err(|e| LlmError::AnthropicError(format!("Request failed: {}", e)))?; - + .map_err(|e| LlmError::ProviderError(format!("Anthropic request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); - return Err(LlmError::AnthropicError(format!( - "API error {}: {}", + return Err(LlmError::ProviderError(format!( + "Anthropic API error {}: {}", status, error_text ))); } - Ok(self.parse_sse_stream(response, tools).await) + let stream = async_stream::stream! { + let mut sse_event_stream = get_sse_events(response); + let mut tool_calls = Vec::new(); + while let Some(event_result) = sse_event_stream.next().await { + match event_result { + Ok(event) => { + if let Some(chunk) = parse_anthropic_event(event, tools.as_ref(), &mut tool_calls) { + yield chunk; + } + }, + Err(e) => yield Err(e), + } + } + }; + + Ok(stream.boxed()) } async fn prompt( &self, message: &str, - options: &LlmApiProviderSharedOptions, + options: &LlmProviderOptions, ) -> Result { let request = AnthropicRequest { model: &options.model, @@ -318,35 +127,33 @@ impl<'a> LlmApiProvider for AnthropicProvider<'a> { .post(MESSAGES_API_URL) .header("anthropic-version", API_VERSION) .header("content-type", "application/json") - .header("x-api-key", self.api_key) + .header("x-api-key", &self.api_key) .json(&request) .send() .await - .map_err(|e| LlmError::AnthropicError(format!("Request failed: {}", e)))?; + .map_err(|e| LlmError::ProviderError(format!("Anthropic request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); - return Err(LlmError::AnthropicError(format!( - "API error {}: {}", + return Err(LlmError::ProviderError(format!( + "Anthropic API error {}: {}", status, error_text ))); } - let anthropic_response: AnthropicResponse = response + let mut anthropic_response: AnthropicResponse = response .json() .await - .map_err(|e| LlmError::AnthropicError(format!("Failed to parse response: {}", e)))?; - + .map_err(|e| LlmError::ProviderError(format!("Failed to parse response: {}", e)))?; let text = anthropic_response .content - .first() + .get_mut(0) .and_then(|block| match block { - AnthropicResponseContentBlock::Text { text } => Some(text.clone()), + AnthropicResponseContentBlock::Text { text } => Some(std::mem::take(text)), _ => None, }) .ok_or_else(|| LlmError::NoResponse)?; - if let Some(usage) = anthropic_response.usage { let usage: LlmUsage = usage.into(); println!("Prompt usage: {:?}", usage); @@ -356,7 +163,7 @@ impl<'a> LlmApiProvider for AnthropicProvider<'a> { } async fn list_models(&self) -> Result, LlmError> { - let models_service = ModelsDevService::new(self.redis.clone(), self.client.clone()); + let models_service = ModelsDevService::new(&self.redis, &self.client); let models = models_service .list_models(ModelsDevServiceProvider::Anthropic) .await?; @@ -364,179 +171,3 @@ impl<'a> LlmApiProvider for AnthropicProvider<'a> { Ok(models) } } - -/// Anthropic API request message -#[derive(Debug, Serialize)] -struct AnthropicMessage<'a> { - role: &'a str, - content: Vec>, -} - -/// Anthropic API request body -#[derive(Debug, Serialize)] -struct AnthropicRequest<'a> { - model: &'a str, - messages: Vec>, - max_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - system: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - stream: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>>, -} - -/// Anthropic tool definition -#[derive(Debug, Serialize)] -struct AnthropicTool<'a> { - name: &'a str, - description: &'a str, - input_schema: &'a serde_json::Value, -} - -/// Anthropic content block for messages -#[derive(Debug, Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum AnthropicContentBlock<'a> { - Text { - text: &'a str, - }, - ToolUse { - id: &'a str, - name: &'a str, - input: &'a HashMap, - }, - ToolResult { - tool_use_id: &'a str, - content: &'a str, - }, -} - -/// Anthropic API response content block -#[derive(Debug, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum AnthropicResponseContentBlock { - Text { text: String }, - ToolUse { id: String, name: String }, -} - -/// Anthropic API response usage -#[derive(Debug, Deserialize)] -struct AnthropicUsage { - input_tokens: Option, - output_tokens: Option, -} - -impl From for LlmUsage { - fn from(usage: AnthropicUsage) -> Self { - LlmUsage { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - cost: None, - } - } -} - -/// Anthropic API response -#[derive(Debug, Deserialize)] -struct AnthropicResponse { - content: Vec, - usage: Option, -} - -/// Anthropic stream response (message start) -#[derive(Debug, Deserialize)] -struct AnthropicStreamResponse { - // id: String, - // #[serde(rename = "type")] - // message_type: String, - // role: String, - // content: Vec, - // model: String, - // stop_reason: Option, - // stop_sequence: Option, - usage: Option, -} - -/// Anthropic streaming event types -#[derive(Debug, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum AnthropicStreamEvent { - MessageStart { - message: AnthropicStreamResponse, - }, - ContentBlockStart { - index: usize, - content_block: AnthropicResponseContentBlock, - }, - ContentBlockDelta { - index: usize, - delta: AnthropicDelta, - }, - ContentBlockStop { - index: usize, - }, - MessageDelta { - // delta: AnthropicMessageDelta, - usage: Option, - }, - MessageStop, - Ping, - Error { - error: AnthropicError, - }, -} - -#[derive(Debug, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum AnthropicDelta { - TextDelta { text: String }, - InputJsonDelta { partial_json: String }, -} - -#[derive(Debug, Deserialize)] -struct AnthropicMessageDelta { - // stop_reason: Option, - // stop_sequence: Option, -} - -#[derive(Debug, Deserialize)] -struct AnthropicError { - #[serde(rename = "type")] - error_type: String, - message: String, -} - -/// Helper struct for tracking streaming tool calls -#[derive(Debug)] -struct AnthropicStreamToolCall { - id: String, - index: usize, - name: String, - /// Partial input parameters (JSON stringified) - input: String, -} - -impl AnthropicStreamToolCall { - /// Convert Anthropic tool call format to ChatRsToolCall - fn convert(self, llm_tools: &[LlmTool]) -> Option { - let input = if self.input.trim().is_empty() { - "{}" - } else { - &self.input - }; - let parameters = serde_json::from_str(input).ok()?; - llm_tools - .iter() - .find(|tool| tool.name == self.name) - .map(|tool| ChatRsToolCall { - id: self.id, - tool_id: tool.tool_id, - tool_name: self.name, - tool_type: tool.tool_type, - parameters, - }) - } -} diff --git a/server/src/provider/anthropic/request.rs b/server/src/provider/anthropic/request.rs new file mode 100644 index 0000000..43af62b --- /dev/null +++ b/server/src/provider/anthropic/request.rs @@ -0,0 +1,134 @@ +use std::collections::HashMap; + +use serde::Serialize; + +use crate::{ + db::models::{ChatRsMessage, ChatRsMessageRole}, + provider::LlmTool, +}; + +pub fn build_anthropic_messages<'a>( + messages: &'a [ChatRsMessage], +) -> (Vec>, Option<&'a str>) { + let system_prompt = messages + .iter() + .rfind(|message| message.role == ChatRsMessageRole::System) + .map(|message| message.content.as_str()); + + let anthropic_messages: Vec = messages + .iter() + .filter_map(|message| { + let role = match message.role { + ChatRsMessageRole::User => "user", + ChatRsMessageRole::Tool => "user", + ChatRsMessageRole::Assistant => "assistant", + ChatRsMessageRole::System => return None, + }; + + let mut content_blocks = Vec::new(); + + // Handle tool result messages + if message.role == ChatRsMessageRole::Tool { + if let Some(executed_call) = &message.meta.tool_call { + content_blocks.push(AnthropicContentBlock::ToolResult { + tool_use_id: &executed_call.id, + content: &message.content, + }); + } + } else { + // Handle regular text content + if !message.content.is_empty() { + content_blocks.push(AnthropicContentBlock::Text { + text: &message.content, + }); + } + // Handle tool calls in assistant messages + if let Some(tool_calls) = message + .meta + .assistant + .as_ref() + .and_then(|a| a.tool_calls.as_ref()) + { + for tool_call in tool_calls { + content_blocks.push(AnthropicContentBlock::ToolUse { + id: &tool_call.id, + name: &tool_call.tool_name, + input: &tool_call.parameters, + }); + } + } + } + + if content_blocks.is_empty() { + return None; + } + + Some(AnthropicMessage { + role, + content: content_blocks, + }) + }) + .collect(); + + (anthropic_messages, system_prompt) +} + +pub fn build_anthropic_tools<'a>(tools: &'a [LlmTool]) -> Vec> { + tools + .iter() + .map(|tool| AnthropicTool { + name: &tool.name, + description: &tool.description, + input_schema: &tool.input_schema, + }) + .collect() +} + +/// Anthropic API request message +#[derive(Debug, Serialize)] +pub struct AnthropicMessage<'a> { + pub role: &'a str, + pub content: Vec>, +} + +/// Anthropic API request body +#[derive(Debug, Serialize)] +pub struct AnthropicRequest<'a> { + pub model: &'a str, + pub messages: Vec>, + pub max_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>>, +} + +/// Anthropic tool definition +#[derive(Debug, Serialize)] +pub struct AnthropicTool<'a> { + name: &'a str, + description: &'a str, + input_schema: &'a serde_json::Value, +} + +/// Anthropic content block for messages +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AnthropicContentBlock<'a> { + Text { + text: &'a str, + }, + ToolUse { + id: &'a str, + name: &'a str, + input: &'a HashMap, + }, + ToolResult { + tool_use_id: &'a str, + content: &'a str, + }, +} diff --git a/server/src/provider/anthropic/response.rs b/server/src/provider/anthropic/response.rs new file mode 100644 index 0000000..daf3e04 --- /dev/null +++ b/server/src/provider/anthropic/response.rs @@ -0,0 +1,209 @@ +use serde::Deserialize; + +use crate::{ + db::models::ChatRsToolCall, + provider::{ + LlmPendingToolCall, LlmStreamChunk, LlmStreamChunkResult, LlmStreamError, LlmTool, LlmUsage, + }, +}; + +/// Parse an Anthropic SSE event. +pub fn parse_anthropic_event( + event: AnthropicStreamEvent, + tools: Option<&Vec>, + tool_calls: &mut Vec, +) -> Option { + match event { + AnthropicStreamEvent::MessageStart { message } => { + if let Some(usage) = message.usage { + return Some(Ok(LlmStreamChunk::Usage(usage.into()))); + } + } + AnthropicStreamEvent::ContentBlockStart { + content_block, + index, + } => match content_block { + AnthropicResponseContentBlock::Text { text } => { + return Some(Ok(LlmStreamChunk::Text(text))); + } + AnthropicResponseContentBlock::ToolUse { id, name } => { + tool_calls.push(AnthropicStreamToolCall { + id, + index, + name, + input: String::with_capacity(100), + }); + } + }, + AnthropicStreamEvent::ContentBlockDelta { delta, index } => match delta { + AnthropicDelta::TextDelta { text } => { + return Some(Ok(LlmStreamChunk::Text(text))); + } + AnthropicDelta::InputJsonDelta { partial_json } => { + if let Some(tool_call) = tool_calls.iter_mut().find(|tc| tc.index == index) { + tool_call.input.push_str(&partial_json); + let chunk = LlmStreamChunk::PendingToolCall(LlmPendingToolCall { + index, + tool_name: tool_call.name.clone(), + }); + return Some(Ok(chunk)); + } + } + }, + AnthropicStreamEvent::ContentBlockStop { index } => { + if let Some(llm_tools) = tools { + if let Some(tc) = tool_calls + .iter() + .position(|tc| tc.index == index) + .map(|i| tool_calls.swap_remove(i)) + { + if let Some(tool_call) = tc.convert(llm_tools) { + let chunk = LlmStreamChunk::ToolCalls(vec![tool_call]); + return Some(Ok(chunk)); + } + } + } + } + AnthropicStreamEvent::MessageDelta { usage } => { + if let Some(usage) = usage { + return Some(Ok(LlmStreamChunk::Usage(usage.into()))); + } + } + AnthropicStreamEvent::Error { error } => { + let error_msg = format!("{}: {}", error.error_type, error.message); + return Some(Err(LlmStreamError::ProviderError(error_msg))); + } + _ => {} // Ignore other events (ping, message_stop) + } + None +} + +/// Anthropic API response content block +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AnthropicResponseContentBlock { + Text { text: String }, + ToolUse { id: String, name: String }, +} + +/// Anthropic API response usage +#[derive(Debug, Deserialize)] +pub struct AnthropicUsage { + input_tokens: Option, + output_tokens: Option, +} + +impl From for LlmUsage { + fn from(usage: AnthropicUsage) -> Self { + LlmUsage { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + cost: None, + } + } +} + +/// Anthropic API response +#[derive(Debug, Deserialize)] +pub struct AnthropicResponse { + pub content: Vec, + pub usage: Option, +} + +/// Anthropic stream response (message start) +#[derive(Debug, Deserialize)] +pub struct AnthropicStreamResponse { + // id: String, + // #[serde(rename = "type")] + // message_type: String, + // role: String, + // content: Vec, + // model: String, + // stop_reason: Option, + // stop_sequence: Option, + usage: Option, +} + +/// Anthropic streaming event types +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AnthropicStreamEvent { + MessageStart { + message: AnthropicStreamResponse, + }, + ContentBlockStart { + index: usize, + content_block: AnthropicResponseContentBlock, + }, + ContentBlockDelta { + index: usize, + delta: AnthropicDelta, + }, + ContentBlockStop { + index: usize, + }, + MessageDelta { + // delta: AnthropicMessageDelta, + usage: Option, + }, + MessageStop, + Ping, + Error { + error: AnthropicError, + }, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AnthropicDelta { + TextDelta { text: String }, + InputJsonDelta { partial_json: String }, +} + +#[derive(Debug, Deserialize)] +pub struct AnthropicMessageDelta { + // stop_reason: Option, + // stop_sequence: Option, +} + +#[derive(Debug, Deserialize)] +pub struct AnthropicError { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +/// Helper struct for tracking streaming tool calls +#[derive(Debug)] +pub struct AnthropicStreamToolCall { + /// Anthropic tool call ID + id: String, + /// Index of the tool call in the message + index: usize, + /// Name of the tool + name: String, + /// Partial input parameters (JSON stringified) + input: String, +} + +impl AnthropicStreamToolCall { + /// Convert Anthropic tool call format to ChatRsToolCall + fn convert(self, llm_tools: &[LlmTool]) -> Option { + let input = if self.input.trim().is_empty() { + "{}" + } else { + &self.input + }; + let parameters = serde_json::from_str(input).ok()?; + llm_tools + .iter() + .find(|tool| tool.name == self.name) + .map(|tool| ChatRsToolCall { + id: self.id, + tool_id: tool.tool_id, + tool_name: self.name, + tool_type: tool.tool_type, + parameters, + }) + } +} diff --git a/server/src/provider/lorem.rs b/server/src/provider/lorem.rs index 0d87efd..4c13aa3 100644 --- a/server/src/provider/lorem.rs +++ b/server/src/provider/lorem.rs @@ -10,18 +10,19 @@ use tokio::time::{interval, Interval}; use crate::{ db::models::ChatRsMessage, provider::{ - LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmStreamChunk, - LlmTool, + LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, + LlmStreamChunkResult, LlmStreamError, LlmTool, }, provider_models::LlmModel, }; -/// A test/dummy provider that streams 'lorem ipsum...' +/// A test/dummy provider that streams 'lorem ipsum...' and emits test errors during the stream +#[derive(Debug, Clone)] pub struct LoremProvider { pub config: LoremConfig, } -#[derive(JsonSchema)] +#[derive(Debug, Clone, JsonSchema)] pub struct LoremConfig { pub interval: u32, } @@ -40,7 +41,7 @@ struct LoremStream { interval: Interval, } impl Stream for LoremStream { - type Item = Result; + type Item = LlmStreamChunkResult; fn poll_next( mut self: Pin<&mut Self>, @@ -55,13 +56,11 @@ impl Stream for LoremStream { let word = self.words[self.index]; self.index += 1; if self.index == 0 || self.index % 10 != 0 { - std::task::Poll::Ready(Some(Ok(LlmStreamChunk { - text: Some(word.to_owned()), - tool_calls: None, - usage: None, - }))) + std::task::Poll::Ready(Some(Ok(LlmStreamChunk::Text(word.to_owned())))) } else { - std::task::Poll::Ready(Some(Err(LlmError::LoremError("Test error")))) + std::task::Poll::Ready(Some(Err(LlmStreamError::ProviderError( + "Test error".into(), + )))) } } std::task::Poll::Pending => std::task::Poll::Pending, @@ -75,8 +74,8 @@ impl LlmApiProvider for LoremProvider { &self, _messages: Vec, _tools: Option>, - _options: &LlmApiProviderSharedOptions, - ) -> Result { + _options: &LlmProviderOptions, + ) -> Result { let lorem_words = vec![ "Lorem ipsum ", "dolor sit ", @@ -106,7 +105,7 @@ impl LlmApiProvider for LoremProvider { "nulla pariatur.", ]; - let stream: LlmApiStream = Box::pin(LoremStream { + let stream: LlmStream = Box::pin(LoremStream { words: lorem_words, index: 0, interval: interval(Duration::from_millis(self.config.interval.into())), @@ -120,7 +119,7 @@ impl LlmApiProvider for LoremProvider { async fn prompt( &self, _request: &str, - _options: &LlmApiProviderSharedOptions, + _options: &LlmProviderOptions, ) -> Result { Ok("Lorem ipsum".to_string()) } diff --git a/server/src/provider/ollama.rs b/server/src/provider/ollama.rs new file mode 100644 index 0000000..68ef215 --- /dev/null +++ b/server/src/provider/ollama.rs @@ -0,0 +1,197 @@ +//! Ollama LLM provider + +mod request; +mod response; + +use rocket::{async_stream, async_trait, futures::StreamExt}; + +use crate::{ + db::models::ChatRsMessage, + provider::{ + utils::get_json_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, + LlmStreamChunk, LlmTool, LlmUsage, + }, + provider_models::LlmModel, +}; + +use { + request::{ + build_ollama_messages, build_ollama_tools, OllamaChatRequest, OllamaCompletionRequest, + OllamaOptions, + }, + response::{ + parse_ollama_event, OllamaCompletionResponse, OllamaModelsResponse, OllamaToolCall, + }, +}; + +const CHAT_API_URL: &str = "/api/chat"; +const COMPLETION_API_URL: &str = "/api/generate"; +const MODELS_API_URL: &str = "/api/tags"; + +/// Ollama chat provider +#[derive(Debug, Clone)] +pub struct OllamaProvider { + client: reqwest::Client, + base_url: String, +} + +impl OllamaProvider { + pub fn new(http_client: &reqwest::Client, base_url: &str) -> Self { + Self { + client: http_client.clone(), + base_url: base_url.trim_end_matches('/').to_string(), + } + } +} + +#[async_trait] +impl LlmApiProvider for OllamaProvider { + async fn chat_stream( + &self, + messages: Vec, + tools: Option>, + options: &LlmProviderOptions, + ) -> Result { + let ollama_messages = build_ollama_messages(&messages); + let ollama_tools = tools.as_ref().map(|t| build_ollama_tools(t)); + let ollama_options = OllamaOptions { + temperature: options.temperature, + num_predict: options.max_tokens, + ..Default::default() + }; + let request = OllamaChatRequest { + model: &options.model, + messages: ollama_messages, + tools: ollama_tools, + stream: Some(true), + options: Some(ollama_options), + }; + + let response = self + .client + .post(format!("{}{}", self.base_url, CHAT_API_URL)) + .header("content-type", "application/json") + .json(&request) + .send() + .await + .map_err(|e| LlmError::ProviderError(format!("Ollama request failed: {}", e)))?; + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(LlmError::ProviderError(format!( + "Ollama API error {}: {}", + status, error_text + ))); + } + + let stream = async_stream::stream! { + let mut json_stream = get_json_events(response); + let mut tool_calls: Vec = Vec::new(); + while let Some(event) = json_stream.next().await { + match event { + Ok(event) => { + for chunk in parse_ollama_event(event, &mut tool_calls) { + yield chunk; + } + } + Err(e) => yield Err(e), + } + } + if !tool_calls.is_empty() { + if let Some(llm_tools) = tools { + let converted = tool_calls + .into_iter() + .filter_map(|tc| tc.function.convert(&llm_tools)) + .collect(); + yield Ok(LlmStreamChunk::ToolCalls(converted)); + } + } + }; + + Ok(stream.boxed()) + } + + async fn prompt( + &self, + message: &str, + options: &LlmProviderOptions, + ) -> Result { + let ollama_options = OllamaOptions { + temperature: options.temperature, + num_predict: options.max_tokens, + ..Default::default() + }; + let request = OllamaCompletionRequest { + model: &options.model, + prompt: message, + stream: Some(false), + options: Some(ollama_options), + }; + let response = self + .client + .post(format!("{}{}", self.base_url, COMPLETION_API_URL)) + .header("content-type", "application/json") + .json(&request) + .send() + .await + .map_err(|e| LlmError::ProviderError(format!("Ollama request failed: {}", e)))?; + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(LlmError::ProviderError(format!( + "Ollama API error {}: {}", + status, error_text + ))); + } + + let ollama_response: OllamaCompletionResponse = response + .json() + .await + .map_err(|e| LlmError::ProviderError(format!("Failed to parse response: {}", e)))?; + if let Some(usage) = Option::::from(&ollama_response) { + println!("Prompt usage: {:?}", usage); + } + if ollama_response.response.is_empty() { + return Err(LlmError::NoResponse); + } + + Ok(ollama_response.response) + } + + async fn list_models(&self) -> Result, LlmError> { + let response = self + .client + .get(format!("{}{}", self.base_url, MODELS_API_URL)) + .send() + .await + .map_err(|e| LlmError::ProviderError(format!("Ollama models request failed: {}", e)))?; + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(LlmError::ProviderError(format!( + "Ollama models API error {}: {}", + status, error_text + ))); + } + + let models_response: OllamaModelsResponse = response.json().await.map_err(|e| { + LlmError::ProviderError(format!("Failed to parse models response: {}", e)) + })?; + let models = models_response + .models + .into_iter() + .map(|model| LlmModel { + id: model.name.clone(), + name: model.name, + temperature: Some(true), + tool_call: Some(true), + modified_at: Some(model.modified_at), + format: Some(model.details.format), + family: Some(model.details.family), + ..Default::default() + }) + .collect(); + + Ok(models) + } +} diff --git a/server/src/provider/ollama/request.rs b/server/src/provider/ollama/request.rs new file mode 100644 index 0000000..520b41f --- /dev/null +++ b/server/src/provider/ollama/request.rs @@ -0,0 +1,166 @@ +//! Ollama API request structures + +use serde::Serialize; + +use crate::{ + db::models::{ChatRsMessage, ChatRsMessageRole}, + provider::LlmTool, + tools::ToolParameters, +}; + +/// Convert ChatRsMessages to Ollama messages +pub fn build_ollama_messages(messages: &[ChatRsMessage]) -> Vec { + messages + .iter() + .map(|msg| { + let role = match msg.role { + ChatRsMessageRole::User => "user", + ChatRsMessageRole::Assistant => "assistant", + ChatRsMessageRole::System => "system", + ChatRsMessageRole::Tool => "tool", + }; + + let mut ollama_msg = OllamaMessage { + role, + content: &msg.content, + tool_calls: None, + tool_name: None, + }; + + // Handle tool calls in assistant messages + if msg.role == ChatRsMessageRole::Assistant { + if let Some(msg_tool_calls) = msg + .meta + .assistant + .as_ref() + .and_then(|m| m.tool_calls.as_ref()) + { + let tool_calls = msg_tool_calls + .iter() + .map(|tc| OllamaToolCall { + function: OllamaToolFunction { + name: &tc.tool_name, + arguments: &tc.parameters, + }, + }) + .collect(); + ollama_msg.tool_calls = Some(tool_calls); + } + } + + // Handle tool messages (results from tool calls) + if msg.role == ChatRsMessageRole::Tool { + if let Some(ref tool_call) = msg.meta.tool_call { + ollama_msg.tool_name = Some(&tool_call.tool_name); + } + } + + ollama_msg + }) + .collect() +} + +/// Convert LlmTools to Ollama tools +pub fn build_ollama_tools(tools: &[LlmTool]) -> Vec { + tools + .iter() + .map(|tool| OllamaTool { + r#type: "function", + function: OllamaToolSpec { + name: &tool.name, + description: &tool.description, + parameters: &tool.input_schema, + }, + }) + .collect() +} + +/// Ollama chat request structure +#[derive(Debug, Serialize)] +pub struct OllamaChatRequest<'a> { + pub model: &'a str, + pub messages: Vec>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +/// Ollama completion request structure +#[derive(Debug, Serialize)] +pub struct OllamaCompletionRequest<'a> { + pub model: &'a str, + pub prompt: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +/// Ollama chat message +#[derive(Debug, Serialize)] +pub struct OllamaMessage<'a> { + pub role: &'a str, + pub content: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_name: Option<&'a str>, +} + +/// Ollama tool call in a message +#[derive(Debug, Serialize)] +pub struct OllamaToolCall<'a> { + pub function: OllamaToolFunction<'a>, +} + +/// Ollama tool function +#[derive(Debug, Serialize)] +pub struct OllamaToolFunction<'a> { + pub name: &'a str, + pub arguments: &'a ToolParameters, +} + +/// Ollama tool definition +#[derive(Debug, Serialize)] +pub struct OllamaTool<'a> { + pub r#type: &'a str, + pub function: OllamaToolSpec<'a>, +} + +/// Ollama tool specification +#[derive(Debug, Serialize)] +pub struct OllamaToolSpec<'a> { + pub name: &'a str, + pub description: &'a str, + pub parameters: &'a serde_json::Value, +} + +/// Ollama model options +#[derive(Debug, Default, Serialize)] +pub struct OllamaOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub num_predict: Option, // Ollama's equivalent to max_tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, +} + +impl Default for OllamaChatRequest<'_> { + fn default() -> Self { + Self { + model: "", + messages: Vec::new(), + tools: None, + stream: None, + options: None, + } + } +} diff --git a/server/src/provider/ollama/response.rs b/server/src/provider/ollama/response.rs new file mode 100644 index 0000000..285b61e --- /dev/null +++ b/server/src/provider/ollama/response.rs @@ -0,0 +1,186 @@ +//! Ollama API response structures + +use serde::Deserialize; + +use crate::{ + db::models::ChatRsToolCall, + provider::{LlmPendingToolCall, LlmStreamChunk, LlmStreamError, LlmTool, LlmUsage}, +}; + +/// Parse Ollama streaming event into LlmStreamChunks, and track tool calls +pub fn parse_ollama_event( + event: OllamaStreamResponse, + tool_calls: &mut Vec, +) -> Vec> { + let mut chunks = Vec::with_capacity(1); + // Handle final message with usage stats + if event.done { + if let Some(usage) = Option::::from(&event) { + chunks.push(Ok(LlmStreamChunk::Usage(usage))); + } + } + + // Handle tool calls in the message + if !event.message.tool_calls.is_empty() { + for (index, tc) in event.message.tool_calls.iter().enumerate() { + let tool_call = LlmPendingToolCall { + index, + tool_name: tc.function.name.clone(), + }; + chunks.push(Ok(LlmStreamChunk::PendingToolCall(tool_call))); + } + tool_calls.extend(event.message.tool_calls); + } + + // Handle text content + if !event.message.content.is_empty() { + chunks.push(Ok(LlmStreamChunk::Text(event.message.content))); + } + + chunks +} + +/// Ollama chat response (streaming) +#[derive(Debug, Deserialize)] +pub struct OllamaStreamResponse { + pub model: String, + pub created_at: String, + pub message: OllamaMessage, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub total_duration: Option, + #[serde(default)] + pub load_duration: Option, + #[serde(default)] + pub prompt_eval_count: Option, + #[serde(default)] + pub prompt_eval_duration: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub eval_duration: Option, +} + +/// Ollama completion response (non-streaming) +#[derive(Debug, Deserialize)] +pub struct OllamaCompletionResponse { + pub model: String, + pub created_at: String, + pub response: String, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub total_duration: Option, + #[serde(default)] + pub load_duration: Option, + #[serde(default)] + pub prompt_eval_count: Option, + #[serde(default)] + pub prompt_eval_duration: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub eval_duration: Option, +} + +/// Ollama message in response +#[derive(Debug, Deserialize)] +pub struct OllamaMessage { + pub role: String, + #[serde(default)] + pub content: String, + #[serde(default)] + pub tool_calls: Vec, +} + +/// Ollama tool call in response +#[derive(Debug, Deserialize)] +pub struct OllamaToolCall { + pub function: OllamaToolFunction, +} + +/// Ollama tool function in response +#[derive(Debug, Deserialize)] +pub struct OllamaToolFunction { + pub name: String, + pub arguments: serde_json::Value, +} + +impl OllamaToolFunction { + /// Convert to ChatRsToolCall if the tool exists in the provided tools + pub fn convert(self, tools: &[LlmTool]) -> Option { + let tool = tools.iter().find(|t| t.name == self.name)?; + let parameters = serde_json::from_value(self.arguments).ok()?; + + Some(ChatRsToolCall { + id: uuid::Uuid::new_v4().to_string(), + parameters, + tool_id: tool.tool_id, + tool_name: self.name, + tool_type: tool.tool_type, + }) + } +} + +/// Convert Ollama usage to LlmUsage +impl From<&OllamaCompletionResponse> for Option { + fn from(response: &OllamaCompletionResponse) -> Self { + if response.prompt_eval_count.is_some() || response.eval_count.is_some() { + Some(LlmUsage { + input_tokens: response.prompt_eval_count, + output_tokens: response.eval_count, + cost: None, + }) + } else { + None + } + } +} + +impl From<&OllamaStreamResponse> for Option { + fn from(response: &OllamaStreamResponse) -> Self { + if response.prompt_eval_count.is_some() || response.eval_count.is_some() { + Some(LlmUsage { + input_tokens: response.prompt_eval_count, + output_tokens: response.eval_count, + cost: None, + }) + } else { + None + } + } +} + +/// Ollama models list response +#[derive(Debug, Deserialize)] +pub struct OllamaModelsResponse { + pub models: Vec, +} + +/// Ollama model information +#[derive(Debug, Deserialize)] +pub struct OllamaModelInfo { + pub name: String, + // pub model: String, + pub modified_at: String, + // pub size: u64, + // pub digest: String, + pub details: OllamaModelDetails, +} + +/// Ollama model details +#[derive(Debug, Deserialize)] +pub struct OllamaModelDetails { + // #[serde(default)] + // pub parent_model: String, + pub format: String, + pub family: String, + // #[serde(default)] + // pub families: Vec, + // pub parameter_size: String, + // #[serde(default)] + // pub quantization_level: Option, +} diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index b8b0a3f..5153d45 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -1,209 +1,77 @@ //! OpenAI (and OpenAI compatible) LLM provider -use rocket::async_trait; -use serde::{Deserialize, Serialize}; +mod request; +mod response; + +use rocket::{async_stream, async_trait, futures::StreamExt}; use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole, ChatRsToolCall}, + db::models::ChatRsMessage, provider::{ - LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmStreamChunk, - LlmTool, LlmUsage, + utils::get_sse_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, + LlmStreamChunk, LlmTool, LlmUsage, }, provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, }; +use { + request::{ + build_openai_messages, build_openai_tools, OpenAIMessage, OpenAIRequest, + OpenAIStreamOptions, + }, + response::{parse_openai_event, OpenAIResponse, OpenAIStreamToolCall}, +}; + const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1"; const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1"; /// OpenAI chat provider -pub struct OpenAIProvider<'a> { +#[derive(Debug, Clone)] +pub struct OpenAIProvider { client: reqwest::Client, - redis: &'a fred::prelude::Client, - api_key: &'a str, - base_url: &'a str, + redis: fred::clients::Client, + api_key: String, + base_url: String, } -impl<'a> OpenAIProvider<'a> { +impl OpenAIProvider { pub fn new( http_client: &reqwest::Client, - redis: &'a fred::prelude::Client, - api_key: &'a str, - base_url: Option<&'a str>, + redis: &fred::clients::Client, + api_key: &str, + base_url: Option<&str>, ) -> Self { Self { client: http_client.clone(), - redis, - api_key, - base_url: base_url.unwrap_or(OPENAI_API_BASE_URL), + redis: redis.clone(), + api_key: api_key.to_owned(), + base_url: base_url.unwrap_or(OPENAI_API_BASE_URL).to_owned(), } } - - fn build_messages(&self, messages: &'a [ChatRsMessage]) -> Vec> { - messages - .iter() - .map(|message| { - let role = match message.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => "system", - ChatRsMessageRole::Tool => "tool", - }; - let openai_message = OpenAIMessage { - role, - content: Some(&message.content), - tool_call_id: message.meta.tool_call.as_ref().map(|tc| tc.id.as_str()), - tool_calls: message - .meta - .assistant - .as_ref() - .and_then(|meta| meta.tool_calls.as_ref()) - .map(|tc| { - tc.iter() - .map(|tc| OpenAIToolCall { - id: &tc.id, - tool_type: "function", - function: OpenAIToolCallFunction { - name: &tc.tool_name, - arguments: serde_json::to_string(&tc.parameters) - .unwrap_or_default(), - }, - }) - .collect() - }), - }; - - openai_message - }) - .collect() - } - - fn build_tools(&self, tools: &'a [LlmTool]) -> Vec { - tools - .iter() - .map(|tool| OpenAITool { - tool_type: "function", - function: OpenAIToolFunction { - name: &tool.name, - description: &tool.description, - parameters: &tool.input_schema, - strict: true, - }, - }) - .collect() - } - - async fn parse_sse_stream( - &self, - mut response: reqwest::Response, - tools: Option>, - ) -> LlmApiStream { - let stream = async_stream::stream! { - let mut buffer = String::new(); - let mut tool_calls: Vec = Vec::new(); - - while let Some(chunk) = response.chunk().await.transpose() { - match chunk { - Ok(bytes) => { - let text = String::from_utf8_lossy(&bytes); - buffer.push_str(&text); - - while let Some(line_end_idx) = buffer.find('\n') { - let line = buffer[..line_end_idx].trim_end_matches('\r'); - - if line.starts_with("data: ") { - let data = &line[6..]; // Remove "data: " prefix - if data.trim().is_empty() || data == "[DONE]" { - buffer.drain(..=line_end_idx); - continue; - } - - match serde_json::from_str::(data) { - Ok(mut response) => { - if let Some(choice) = response.choices.pop() { - if let Some(delta) = choice.delta { - if delta.content.is_some() { - yield Ok(LlmStreamChunk { - text: delta.content, - ..Default::default() - }); - } - - if let Some(tool_calls_delta) = delta.tool_calls { - for tool_call_delta in tool_calls_delta { - if let Some(tc) = tool_calls.iter_mut().find(|tc| tc.index == tool_call_delta.index) { - if let Some(function_arguments) = tool_call_delta.function.arguments { - *tc.function.arguments.get_or_insert_default() += &function_arguments; - } - } else { - tool_calls.push(tool_call_delta); - } - } - } - } - } - - // Yield usage information if available - if let Some(usage) = response.usage { - yield Ok(LlmStreamChunk { - usage: Some(usage.into()), - ..Default::default() - }); - } - } - Err(e) => { - rocket::warn!("Failed to parse SSE event: {} | Data: {}", e, data); - } - } - } else if line.starts_with("event: ") { - let event_type = &line[7..]; - rocket::debug!("SSE event type: {}", event_type); - } else if !line.trim().is_empty() && !line.starts_with(":") { - rocket::debug!("Unexpected SSE line: {}", line); - } - - buffer.drain(..=line_end_idx); - } - } - Err(e) => { - rocket::warn!("Stream chunk error: {}", e); - yield Err(LlmError::OpenAIError(format!("Stream error: {}", e))); - break; - } - } - } - - if let Some(rs_chat_tools) = tools { - if !tool_calls.is_empty() { - yield Ok(LlmStreamChunk { - tool_calls: Some(tool_calls.into_iter().filter_map(|tc| tc.convert(&rs_chat_tools)).collect()), - ..Default::default() - }); - } - } - - rocket::debug!("SSE stream ended"); - }; - - Box::pin(stream) - } } #[async_trait] -impl<'a> LlmApiProvider for OpenAIProvider<'a> { +impl LlmApiProvider for OpenAIProvider { async fn chat_stream( &self, messages: Vec, tools: Option>, - options: &LlmApiProviderSharedOptions, - ) -> Result { - let openai_messages = self.build_messages(&messages); - let openai_tools = tools.as_ref().map(|t| self.build_tools(t)); + options: &LlmProviderOptions, + ) -> Result { + let openai_messages = build_openai_messages(&messages); + let openai_tools = tools.as_ref().map(|t| build_openai_tools(t)); let request = OpenAIRequest { model: &options.model, messages: openai_messages, - max_tokens: options.max_tokens, + max_tokens: (options.max_tokens.is_some() && self.base_url != OPENAI_API_BASE_URL) + .then(|| options.max_tokens.expect("already checked for Some value")), + // OpenAI official API has deprecated `max_tokens` for `max_completion_tokens` + max_completion_tokens: (options.max_tokens.is_some() + && self.base_url == OPENAI_API_BASE_URL) + .then(|| options.max_tokens.expect("already checked for Some value")), temperature: options.temperature, + store: (self.base_url == OPENAI_API_BASE_URL).then_some(false), stream: Some(true), stream_options: Some(OpenAIStreamOptions { include_usage: true, @@ -219,24 +87,48 @@ impl<'a> LlmApiProvider for OpenAIProvider<'a> { .json(&request) .send() .await - .map_err(|e| LlmError::OpenAIError(format!("Request failed: {}", e)))?; + .map_err(|e| LlmError::ProviderError(format!("OpenAI request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); - return Err(LlmError::OpenAIError(format!( - "API error {}: {}", + return Err(LlmError::ProviderError(format!( + "OpenAI API error {}: {}", status, error_text ))); } - Ok(self.parse_sse_stream(response, tools).await) + let stream = async_stream::stream! { + let mut sse_event_stream = get_sse_events(response); + let mut tool_calls: Vec = Vec::new(); + while let Some(event) = sse_event_stream.next().await { + match event { + Ok(event) => { + for chunk in parse_openai_event(event, &mut tool_calls) { + yield chunk; + } + } + Err(e) => yield Err(e), + } + } + if !tool_calls.is_empty() { + if let Some(llm_tools) = tools { + let converted = tool_calls + .into_iter() + .filter_map(|tc| tc.convert(&llm_tools)) + .collect(); + yield Ok(LlmStreamChunk::ToolCalls(converted)); + } + } + }; + + Ok(stream.boxed()) } async fn prompt( &self, message: &str, - options: &LlmApiProviderSharedOptions, + options: &LlmProviderOptions, ) -> Result { let request = OpenAIRequest { model: &options.model, @@ -258,27 +150,27 @@ impl<'a> LlmApiProvider for OpenAIProvider<'a> { .json(&request) .send() .await - .map_err(|e| LlmError::OpenAIError(format!("Request failed: {}", e)))?; + .map_err(|e| LlmError::ProviderError(format!("OpenAI request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); - return Err(LlmError::OpenAIError(format!( - "API error {}: {}", + return Err(LlmError::ProviderError(format!( + "OpenAI API error {}: {}", status, error_text ))); } - let openai_response: OpenAIResponse = response + let mut openai_response: OpenAIResponse = response .json() .await - .map_err(|e| LlmError::OpenAIError(format!("Failed to parse response: {}", e)))?; + .map_err(|e| LlmError::ProviderError(format!("Failed to parse response: {}", e)))?; let text = openai_response .choices - .first() - .and_then(|choice| choice.message.as_ref()) - .and_then(|message| message.content.as_ref()) + .get_mut(0) + .and_then(|choice| choice.message.as_mut()) + .and_then(|message| message.content.take()) .ok_or(LlmError::NoResponse)?; if let Some(usage) = openai_response.usage { @@ -286,14 +178,14 @@ impl<'a> LlmApiProvider for OpenAIProvider<'a> { println!("Prompt usage: {:?}", usage); } - Ok(text.clone()) + Ok(text) } async fn list_models(&self) -> Result, LlmError> { - let models_service = ModelsDevService::new(self.redis.clone(), self.client.clone()); + let models_service = ModelsDevService::new(&self.redis, &self.client); let models = models_service .list_models({ - match self.base_url { + match self.base_url.as_str() { OPENROUTER_API_BASE_URL => ModelsDevServiceProvider::OpenRouter, _ => ModelsDevServiceProvider::OpenAI, } @@ -303,159 +195,3 @@ impl<'a> LlmApiProvider for OpenAIProvider<'a> { Ok(models) } } - -/// OpenAI API request body -#[derive(Debug, Default, Serialize)] -struct OpenAIRequest<'a> { - model: &'a str, - messages: Vec>, - max_tokens: Option, - temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - stream: Option, - #[serde(skip_serializing_if = "Option::is_none")] - stream_options: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>>, -} - -/// OpenAI API request message -#[derive(Debug, Default, Serialize)] -struct OpenAIMessage<'a> { - role: &'a str, - #[serde(skip_serializing_if = "Option::is_none")] - content: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_call_id: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>>, -} - -/// OpenAI tool definition -#[derive(Debug, Serialize)] -struct OpenAITool<'a> { - #[serde(rename = "type")] - tool_type: &'a str, - function: OpenAIToolFunction<'a>, -} - -/// OpenAI tool function definition -#[derive(Debug, Serialize)] -struct OpenAIToolFunction<'a> { - name: &'a str, - description: &'a str, - parameters: &'a serde_json::Value, - strict: bool, -} - -/// OpenAI tool call in messages -#[derive(Debug, Serialize)] -struct OpenAIToolCall<'a> { - id: &'a str, - #[serde(rename = "type")] - tool_type: &'a str, - function: OpenAIToolCallFunction<'a>, -} - -/// OpenAI tool call function in messages -#[derive(Debug, Serialize)] -struct OpenAIToolCallFunction<'a> { - name: &'a str, - arguments: String, -} - -/// OpenAI API request stream options -#[derive(Debug, Serialize)] -struct OpenAIStreamOptions { - include_usage: bool, -} - -/// OpenAI API response -#[derive(Debug, Deserialize)] -struct OpenAIResponse { - choices: Vec, - usage: Option, -} - -/// OpenAI API streaming response -#[derive(Debug, Deserialize)] -struct OpenAIStreamResponse { - choices: Vec, - usage: Option, -} - -/// OpenAI API response choice -#[derive(Debug, Deserialize)] -struct OpenAIChoice { - message: Option, - delta: Option, - // finish_reason: Option, -} - -/// OpenAI API response message -#[derive(Debug, Deserialize)] -struct OpenAIResponseMessage { - // role: String, - content: Option, -} - -/// OpenAI API streaming delta -#[derive(Debug, Deserialize)] -struct OpenAIResponseDelta { - // role: Option, - content: Option, - tool_calls: Option>, -} - -/// OpenAI streaming tool call -#[derive(Debug, Deserialize)] -struct OpenAIStreamToolCall { - id: Option, - index: u32, - function: OpenAIStreamToolCallFunction, -} - -impl OpenAIStreamToolCall { - /// Convert OpenAI tool call format to ChatRsToolCall, add tool ID - fn convert(self, rs_chat_tools: &[LlmTool]) -> Option { - let id = self.id?; - let tool_name = self.function.name?; - let parameters = serde_json::from_str(&self.function.arguments?).ok()?; - rs_chat_tools - .iter() - .find(|tool| tool.name == tool_name) - .map(|tool| ChatRsToolCall { - id, - tool_id: tool.tool_id, - tool_name, - tool_type: tool.tool_type, - parameters, - }) - } -} - -/// OpenAI streaming tool call function -#[derive(Debug, Deserialize)] -struct OpenAIStreamToolCallFunction { - name: Option, - arguments: Option, -} - -/// OpenAI API response usage -#[derive(Debug, Deserialize)] -struct OpenAIUsage { - prompt_tokens: Option, - completion_tokens: Option, - cost: Option, - // total_tokens: Option, -} - -impl From for LlmUsage { - fn from(usage: OpenAIUsage) -> Self { - LlmUsage { - input_tokens: usage.prompt_tokens, - output_tokens: usage.completion_tokens, - cost: usage.cost, - } - } -} diff --git a/server/src/provider/openai/request.rs b/server/src/provider/openai/request.rs new file mode 100644 index 0000000..9feafcf --- /dev/null +++ b/server/src/provider/openai/request.rs @@ -0,0 +1,131 @@ +use serde::Serialize; + +use crate::{ + db::models::{ChatRsMessage, ChatRsMessageRole}, + provider::LlmTool, +}; + +pub fn build_openai_messages<'a>(messages: &'a [ChatRsMessage]) -> Vec> { + messages + .iter() + .map(|message| { + let role = match message.role { + ChatRsMessageRole::User => "user", + ChatRsMessageRole::Assistant => "assistant", + ChatRsMessageRole::System => "system", + ChatRsMessageRole::Tool => "tool", + }; + let openai_message = OpenAIMessage { + role, + content: Some(&message.content), + tool_call_id: message.meta.tool_call.as_ref().map(|tc| tc.id.as_str()), + tool_calls: message + .meta + .assistant + .as_ref() + .and_then(|meta| meta.tool_calls.as_ref()) + .map(|tc| { + tc.iter() + .map(|tc| OpenAIToolCall { + id: &tc.id, + tool_type: "function", + function: OpenAIToolCallFunction { + name: &tc.tool_name, + arguments: serde_json::to_string(&tc.parameters) + .unwrap_or_default(), + }, + }) + .collect() + }), + }; + + openai_message + }) + .collect() +} + +pub fn build_openai_tools<'a>(tools: &'a [LlmTool]) -> Vec> { + tools + .iter() + .map(|tool| OpenAITool { + tool_type: "function", + function: OpenAIToolFunction { + name: &tool.name, + description: &tool.description, + parameters: &tool.input_schema, + strict: true, + }, + }) + .collect() +} + +/// OpenAI API request body +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest<'a> { + pub model: &'a str, + pub messages: Vec>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>>, +} + +/// OpenAI API request stream options +#[derive(Debug, Serialize)] +pub struct OpenAIStreamOptions { + pub include_usage: bool, +} + +/// OpenAI API request message +#[derive(Debug, Default, Serialize)] +pub struct OpenAIMessage<'a> { + pub role: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>>, +} + +/// OpenAI tool definition +#[derive(Debug, Serialize)] +pub struct OpenAITool<'a> { + #[serde(rename = "type")] + tool_type: &'a str, + function: OpenAIToolFunction<'a>, +} + +/// OpenAI tool function definition +#[derive(Debug, Serialize)] +pub struct OpenAIToolFunction<'a> { + name: &'a str, + description: &'a str, + parameters: &'a serde_json::Value, + strict: bool, +} + +/// OpenAI tool call in messages +#[derive(Debug, Serialize)] +pub struct OpenAIToolCall<'a> { + id: &'a str, + #[serde(rename = "type")] + tool_type: &'a str, + function: OpenAIToolCallFunction<'a>, +} + +/// OpenAI tool call function in messages +#[derive(Debug, Serialize)] +pub struct OpenAIToolCallFunction<'a> { + name: &'a str, + arguments: String, +} diff --git a/server/src/provider/openai/response.rs b/server/src/provider/openai/response.rs new file mode 100644 index 0000000..534d45d --- /dev/null +++ b/server/src/provider/openai/response.rs @@ -0,0 +1,142 @@ +use serde::Deserialize; + +use crate::{ + db::models::ChatRsToolCall, + provider::{LlmPendingToolCall, LlmStreamChunk, LlmStreamChunkResult, LlmTool, LlmUsage}, +}; + +/// Parse chunks from an OpenAI SSE event +pub fn parse_openai_event( + mut event: OpenAIStreamResponse, + tool_calls: &mut Vec, +) -> Vec { + let mut chunks = Vec::with_capacity(1); + if let Some(delta) = event.choices.pop().and_then(|c| c.delta) { + if let Some(text) = delta.content { + chunks.push(Ok(LlmStreamChunk::Text(text))); + } + if let Some(tool_calls_delta) = delta.tool_calls { + for tool_call_delta in tool_calls_delta { + if let Some(tc) = tool_calls + .iter_mut() + .find(|tc| tc.index == tool_call_delta.index) + { + if let Some(function_arguments) = tool_call_delta.function.arguments { + *tc.function.arguments.get_or_insert_default() += &function_arguments; + } + if let Some(ref tool_name) = tc.function.name { + let chunk = LlmStreamChunk::PendingToolCall(LlmPendingToolCall { + index: tool_call_delta.index, + tool_name: tool_name.clone(), + }); + chunks.push(Ok(chunk)); + } + } else { + if let Some(ref tool_name) = tool_call_delta.function.name { + let chunk = LlmStreamChunk::PendingToolCall(LlmPendingToolCall { + index: tool_call_delta.index, + tool_name: tool_name.clone(), + }); + chunks.push(Ok(chunk)); + } + tool_calls.push(tool_call_delta); + } + } + } + } + if let Some(usage) = event.usage { + chunks.push(Ok(LlmStreamChunk::Usage(usage.into()))); + } + + chunks +} + +/// OpenAI API response +#[derive(Debug, Deserialize)] +pub struct OpenAIResponse { + pub choices: Vec, + pub usage: Option, +} + +/// OpenAI API streaming response +#[derive(Debug, Deserialize)] +pub struct OpenAIStreamResponse { + choices: Vec, + usage: Option, +} + +/// OpenAI API response choice +#[derive(Debug, Deserialize)] +pub struct OpenAIChoice { + pub message: Option, + pub delta: Option, + // finish_reason: Option, +} + +/// OpenAI API response message +#[derive(Debug, Deserialize)] +pub struct OpenAIResponseMessage { + // role: String, + pub content: Option, +} + +/// OpenAI API streaming delta +#[derive(Debug, Deserialize)] +pub struct OpenAIResponseDelta { + // role: Option, + content: Option, + tool_calls: Option>, +} + +/// OpenAI streaming tool call +#[derive(Debug, Deserialize)] +pub struct OpenAIStreamToolCall { + id: Option, + index: usize, + function: OpenAIStreamToolCallFunction, +} + +impl OpenAIStreamToolCall { + /// Convert OpenAI tool call format to ChatRsToolCall, add tool ID + pub fn convert(self, rs_chat_tools: &[LlmTool]) -> Option { + let id = self.id?; + let tool_name = self.function.name?; + let parameters = serde_json::from_str(&self.function.arguments?).ok()?; + rs_chat_tools + .iter() + .find(|tool| tool.name == tool_name) + .map(|tool| ChatRsToolCall { + id, + tool_id: tool.tool_id, + tool_name, + tool_type: tool.tool_type, + parameters, + }) + } +} + +/// OpenAI streaming tool call function +#[derive(Debug, Deserialize)] +struct OpenAIStreamToolCallFunction { + name: Option, + arguments: Option, +} + +/// OpenAI API response usage +#[derive(Debug, Deserialize)] +pub struct OpenAIUsage { + prompt_tokens: Option, + completion_tokens: Option, + cost: Option, + // total_tokens: Option, +} + +impl From for LlmUsage { + fn from(usage: OpenAIUsage) -> Self { + LlmUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cost: usage.cost, + } + } +} diff --git a/server/src/provider/utils.rs b/server/src/provider/utils.rs new file mode 100644 index 0000000..1845ab1 --- /dev/null +++ b/server/src/provider/utils.rs @@ -0,0 +1,47 @@ +use rocket::futures::TryStreamExt; +use serde::de::DeserializeOwned; +use tokio_stream::{Stream, StreamExt}; +use tokio_util::{ + codec::{FramedRead, LinesCodec}, + io::StreamReader, +}; + +use crate::provider::LlmStreamError; + +/// Get a stream of deserialized events from a provider SSE stream. +pub fn get_sse_events( + response: reqwest::Response, +) -> impl Stream> { + let stream_reader = StreamReader::new(response.bytes_stream().map_err(std::io::Error::other)); + let line_reader = FramedRead::new(stream_reader, LinesCodec::new()); + + line_reader.filter_map(|line_result| { + match line_result { + Ok(line) => { + if line.len() >= 6 && line.as_bytes().starts_with(b"data: ") { + let data = &line[6..]; // Skip "data: " prefix + if data.trim_start().is_empty() || data == "[DONE]" { + None // Skip empty lines and termination markers + } else { + Some(serde_json::from_str::(data).map_err(LlmStreamError::Parsing)) + } + } else { + None // Ignore non-data lines + } + } + Err(e) => Some(Err(LlmStreamError::Decoding(e))), + } + }) +} + +/// Get a stream of deserialized events from a provider JSON stream, not SSE (e.g. Ollama uses this format). +pub fn get_json_events( + response: reqwest::Response, +) -> impl Stream> { + let stream_reader = StreamReader::new(response.bytes_stream().map_err(std::io::Error::other)); + let line_reader = FramedRead::new(stream_reader, LinesCodec::new()); + line_reader.map(|line_result| match line_result { + Ok(line) => serde_json::from_str::(&line).map_err(LlmStreamError::Parsing), + Err(e) => Err(LlmStreamError::Decoding(e)), + }) +} diff --git a/server/src/provider_models.rs b/server/src/provider_models.rs index 58e41f1..6ed5bc8 100644 --- a/server/src/provider_models.rs +++ b/server/src/provider_models.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use enum_iterator::{all, Sequence}; -use fred::prelude::*; +use fred::prelude::{HashesInterface, KeysInterface}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -13,29 +13,29 @@ const CACHE_TTL: i64 = 86400; // 1 day in seconds /// A model supported by the LLM provider #[derive(Debug, Default, Clone, JsonSchema, Serialize, Deserialize)] pub struct LlmModel { - id: String, - name: String, + pub id: String, + pub name: String, #[serde(skip_serializing_if = "Option::is_none")] - attachment: Option, + pub attachment: Option, #[serde(skip_serializing_if = "Option::is_none")] - reasoning: Option, + pub reasoning: Option, #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, + pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] - tool_call: Option, + pub tool_call: Option, #[serde(skip_serializing_if = "Option::is_none")] - release_date: Option, + pub release_date: Option, #[serde(skip_serializing_if = "Option::is_none")] - knowledge: Option, + pub knowledge: Option, #[serde(skip_serializing_if = "Option::is_none")] - modalities: Option, + pub modalities: Option, // Ollama fields #[serde(skip_serializing_if = "Option::is_none")] - modified_at: Option, + pub modified_at: Option, #[serde(skip_serializing_if = "Option::is_none")] - format: Option, + pub format: Option, #[serde(skip_serializing_if = "Option::is_none")] - family: Option, + pub family: Option, } #[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)] @@ -56,13 +56,16 @@ pub enum ModalityType { /// Service to fetch and cache LLM model list from https://models.dev pub struct ModelsDevService { - redis: Client, + redis: fred::clients::Client, http_client: reqwest::Client, } impl ModelsDevService { - pub fn new(redis: Client, http_client: reqwest::Client) -> Self { - Self { redis, http_client } + pub fn new(redis: &fred::clients::Client, http_client: &reqwest::Client) -> Self { + Self { + redis: redis.clone(), + http_client: http_client.clone(), + } } pub async fn list_models( diff --git a/server/src/redis.rs b/server/src/redis.rs index 6de13c0..978f51b 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -1,44 +1,31 @@ -use std::{ops::Deref, time::Duration}; +use std::{ops::Deref, sync::Arc, time::Duration}; -use fred::prelude::{Builder, Client, ClientLike, Config, Pool, TcpConfig}; +use deadpool::managed; +use fred::prelude::{Builder, Client, ClientLike, ReconnectPolicy, TcpConfig}; use rocket::{ + async_trait, fairing::AdHoc, http::Status, request::{FromRequest, Outcome}, + Request, }; use rocket_okapi::OpenApiFromRequest; +use tokio::sync::Mutex; use crate::config::get_app_config; -/// Redis connection, available as a request guard. When used as a request parameter, -/// it will retrieve a connection from the managed Redis pool. -#[derive(OpenApiFromRequest)] -pub struct RedisClient(Client); -impl Deref for RedisClient { - type Target = Client; - fn deref(&self) -> &Self::Target { - &self.0 - } -} +/// Default size of the static Redis pool. +const REDIS_POOL_SIZE: usize = 4; +/// Default maximum number of concurrent exclusive clients (e.g. max concurrent streams) +const MAX_EXCLUSIVE_CLIENTS: usize = 20; +/// Timeout for connecting and executing commands. +const CLIENT_TIMEOUT: Duration = Duration::from_secs(6); +/// Interval for checking idle exclusive clients. +const IDLE_TASK_INTERVAL: Duration = Duration::from_secs(30); +/// Shut down exclusive clients after this period of inactivity. +const IDLE_TIME: Duration = Duration::from_secs(60); -/// Retrieve a connection from the managed Redis pool. Responds with an -/// internal server error if Redis not initialized. -#[rocket::async_trait] -impl<'r> FromRequest<'r> for RedisClient { - type Error = String; - - async fn from_request(req: &'r rocket::Request<'_>) -> Outcome { - let Some(pool) = req.rocket().state::() else { - return Outcome::Error(( - Status::InternalServerError, - "Redis not initialized".to_owned(), - )); - }; - Outcome::Success(RedisClient(pool.next().clone())) - } -} - -/// Fairing that sets up and initializes the Redis database +/// Fairing that sets up and initializes the Redis connection pool. pub fn setup_redis() -> AdHoc { AdHoc::on_ignite("Redis", |rocket| async { rocket @@ -46,32 +33,174 @@ pub fn setup_redis() -> AdHoc { "Initialize Redis connection", |rocket| async { let app_config = get_app_config(&rocket); - let config = Config::from_url(&app_config.redis_url) + let config = fred::prelude::Config::from_url(&app_config.redis_url) .expect("RS_CHAT_REDIS_URL should be valid Redis URL"); - let pool = Builder::from_config(config) - .with_connection_config(|config| { - config.connection_timeout = Duration::from_secs(4); - config.tcp = TcpConfig { - nodelay: Some(true), - ..Default::default() - }; - }) - .build_pool(app_config.redis_pool.unwrap_or(2)) - .expect("Failed to build Redis pool"); - pool.init().await.expect("Failed to connect to Redis"); - - rocket.manage(pool) + + // Build and initialize the static Redis pool + let static_pool = + build_redis_pool(config, app_config.redis_pool.unwrap_or(REDIS_POOL_SIZE)) + .expect("Failed to build static Redis pool"); + static_pool.init().await.expect("Redis connection failed"); + + // Build and initialize the dynamic, exclusive Redis pool for long-running tasks + let exclusive_manager = ExclusiveClientManager::new(static_pool.clone()); + let exclusive_pool: ExclusiveClientPool = + managed::Pool::builder(exclusive_manager) + .max_size(app_config.max_streams.unwrap_or(MAX_EXCLUSIVE_CLIENTS)) + .runtime(deadpool::Runtime::Tokio1) + .create_timeout(Some(CLIENT_TIMEOUT)) + .recycle_timeout(Some(CLIENT_TIMEOUT)) + .wait_timeout(Some(CLIENT_TIMEOUT)) + .build() + .expect("Failed to build exclusive Redis pool"); + + // Spawn a task to periodically clean up idle exclusive clients + let idle_task_pool = exclusive_pool.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(IDLE_TASK_INTERVAL); + loop { + interval.tick().await; + idle_task_pool.retain(|_, metrics| metrics.last_used() < IDLE_TIME); + } + }); + + rocket.manage(static_pool).manage(exclusive_pool) }, )) .attach(AdHoc::on_shutdown("Shutdown Redis connection", |rocket| { Box::pin(async { - if let Some(pool) = rocket.state::() { - rocket::info!("Shutting down Redis connection"); + if let Some(pool) = rocket.state::() { + rocket::info!("Shutting down static Redis pool"); if let Err(err) = pool.quit().await { - rocket::error!("Failed to shutdown Redis: {}", err); + rocket::warn!("Failed to shutdown Redis: {}", err); + } + } + if let Some(exclusive_pool) = rocket.state::() { + rocket::info!("Shutting down exclusive Redis pool"); + for client in exclusive_pool.manager().clients.lock().await.iter() { + if let Err(err) = client.quit().await { + rocket::warn!("Failed to shutdown Redis client: {}", err); + } } } }) })) }) } + +pub fn build_redis_pool( + redis_config: fred::prelude::Config, + pool_size: usize, +) -> Result { + Builder::from_config(redis_config) + .with_connection_config(|config| { + config.connection_timeout = CLIENT_TIMEOUT; + config.internal_command_timeout = CLIENT_TIMEOUT; + config.max_command_attempts = 2; + config.tcp = TcpConfig { + nodelay: Some(true), + ..Default::default() + }; + }) + .set_policy(ReconnectPolicy::new_linear(0, 10_000, 1000)) + .with_performance_config(|config| { + config.default_command_timeout = CLIENT_TIMEOUT; + }) + .build_pool(pool_size) +} + +/// Request guard for getting a Redis client from the shared static pool. +#[derive(Debug, OpenApiFromRequest)] +pub struct RedisClient(Client); +impl Deref for RedisClient { + type Target = Client; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +#[async_trait] +impl<'r> FromRequest<'r> for RedisClient { + type Error = (); + async fn from_request(req: &'r Request<'_>) -> Outcome { + use fred::clients::Pool; + let pool = req.rocket().state::().expect("Exists"); + Outcome::Success(RedisClient(pool.next().clone())) + } +} + +/// A pool of Redis clients with exclusive connections for long-running operations. Will +/// be stored in Rocket's managed state. +pub type ExclusiveClientPool = managed::Pool; + +/// Deadpool implementation for a pool of exclusive Redis clients. +#[derive(Debug)] +pub struct ExclusiveClientManager { + pool: fred::clients::Pool, + clients: Arc>>, +} +impl ExclusiveClientManager { + pub fn new(pool: fred::clients::Pool) -> Self { + Self { + pool, + clients: Arc::default(), + } + } +} +impl managed::Manager for ExclusiveClientManager { + type Type = Client; + type Error = fred::error::Error; + + async fn create(&self) -> Result { + let client = self.pool.next().clone_new(); + client.init().await?; + self.clients.lock().await.push(client.clone()); + Ok(client) + } + + async fn recycle( + &self, + client: &mut Client, + _: &managed::Metrics, + ) -> managed::RecycleResult { + if !client.is_connected() { + client.init().await?; + } + let _: () = client.ping(None).await?; + Ok(()) + } + + fn detach(&self, client: &mut Self::Type) { + let client = client.clone(); + let clients = self.clients.clone(); + tokio::spawn(async move { + clients.lock().await.retain(|c| c.id() != client.id()); + if let Err(err) = client.quit().await { + rocket::error!("Failed to disconnect Redis client: {}", err); + } + }); + } +} + +/// Request guard to get a Redis client with an exclusive connection for long-running operations. +#[derive(Debug, OpenApiFromRequest)] +pub struct ExclusiveRedisClient(pub managed::Object); +impl Deref for ExclusiveRedisClient { + type Target = Client; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +#[async_trait] +impl<'r> FromRequest<'r> for ExclusiveRedisClient { + type Error = (); + async fn from_request(req: &'r Request<'_>) -> Outcome { + let pool = req.rocket().state::().expect("Exists"); + match pool.get().await { + Ok(client) => Outcome::Success(ExclusiveRedisClient(client)), + Err(err) => { + rocket::error!("Failed to initialize Redis client: {}", err); + Outcome::Error((Status::InternalServerError, ())) + } + } + } +} diff --git a/server/src/stream.rs b/server/src/stream.rs new file mode 100644 index 0000000..65be5e6 --- /dev/null +++ b/server/src/stream.rs @@ -0,0 +1,92 @@ +mod llm_writer; +mod reader; + +use std::collections::HashMap; + +use fred::{ + prelude::{FredResult, KeysInterface, StreamsInterface}, + types::scan::ScanType, +}; + +pub use llm_writer::*; +pub use reader::*; + +use rocket::{ + async_trait, + http::Status, + request::{FromRequest, Outcome}, + Request, +}; +use rocket_okapi::OpenApiFromRequest; +use uuid::Uuid; + +/// Get the key prefix for the user's chat streams in Redis +fn get_chat_stream_prefix(user_id: &Uuid) -> String { + format!("user:{}:chat:", user_id) +} + +/// Get the key of the chat stream in Redis for the given user and session ID +fn get_chat_stream_key(user_id: &Uuid, session_id: &Uuid) -> String { + format!("{}{}", get_chat_stream_prefix(user_id), session_id) +} + +/// Get the ongoing chat stream sessions for a user. +pub async fn get_current_chat_streams( + redis: &fred::clients::Client, + user_id: &Uuid, +) -> FredResult> { + let prefix = get_chat_stream_prefix(user_id); + let pattern = format!("{}*", prefix); + let (_, keys): (String, Vec) = redis + .scan_page("0", &pattern, Some(20), Some(ScanType::Stream)) + .await?; + Ok(keys + .into_iter() + .filter_map(|key| Some(key.strip_prefix(&prefix)?.to_string())) + .collect()) +} + +/// Check if the chat stream exists. +pub async fn check_chat_stream_exists( + redis: &fred::clients::Client, + user_id: &Uuid, + session_id: &Uuid, +) -> FredResult { + let key = get_chat_stream_key(user_id, session_id); + let first_entry: Option<()> = redis.xread(Some(1), None, &key, "0-0").await?; + Ok(first_entry.is_some()) +} + +/// Cancel a stream by adding a `cancel` event to the stream and then deleting it from Redis +/// (not using a pipeline since we need to ensure the `cancel` event is processed before deleting the stream). +pub async fn cancel_current_chat_stream( + redis: &fred::clients::Client, + user_id: &Uuid, + session_id: &Uuid, +) -> FredResult<()> { + let key = get_chat_stream_key(user_id, session_id); + let entry: HashMap = RedisStreamChunk::Cancel.into(); + let _: () = redis.xadd(&key, true, None, "*", entry).await?; + redis.del(&key).await +} + +/// Request guard to extract the Last-Event-ID from the request headers +#[derive(OpenApiFromRequest)] +pub struct LastEventId(String); +impl std::ops::Deref for LastEventId { + type Target = str; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +#[async_trait] +impl<'r> FromRequest<'r> for LastEventId { + type Error = (); + + async fn from_request(req: &'r Request<'_>) -> Outcome { + match req.headers().get_one("Last-Event-ID") { + Some(event_id) => Outcome::Success(LastEventId(event_id.to_owned())), + None => Outcome::Error((Status::BadRequest, ())), + } + } +} diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs new file mode 100644 index 0000000..0c3f499 --- /dev/null +++ b/server/src/stream/llm_writer.rs @@ -0,0 +1,593 @@ +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; + +use fred::prelude::{FredResult, KeysInterface, StreamsInterface}; +use rocket::futures::StreamExt; +use serde::Serialize; +use uuid::Uuid; + +use crate::{ + db::models::ChatRsToolCall, + provider::{LlmPendingToolCall, LlmStream, LlmStreamChunk, LlmStreamError, LlmUsage}, + redis::ExclusiveRedisClient, + stream::get_chat_stream_key, +}; + +/// Interval at which chunks are flushed to the Redis stream. +const FLUSH_INTERVAL: Duration = Duration::from_millis(500); +/// Max accumulated size of the text chunk before it is automatically flushed to Redis. +const MAX_CHUNK_SIZE: usize = 200; +/// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) +const STREAM_EXPIRE: i64 = 30; +/// Timeout waiting for data from the LLM stream. +const LLM_TIMEOUT: Duration = Duration::from_secs(60); +/// Interval for sending ping messages to the Redis stream. +const PING_INTERVAL: Duration = Duration::from_secs(5); + +/// Utility for processing an incoming LLM response stream and writing to a Redis stream. +#[derive(Debug)] +pub struct LlmStreamWriter { + /// Redis client with an exclusive connection. + redis: ExclusiveRedisClient, + /// The key of the Redis stream. + key: String, + /// The current chunk of data being processed. + current_chunk: ChunkState, + /// Accumulated text response from the assistant. + complete_text: Option, + /// Accumulated tool calls from the assistant. + tool_calls: Option>, + /// Accumulated errors during the stream from the LLM provider. + errors: Option>, + /// Accumulated usage information from the LLM provider. + usage: Option, +} + +/// Internal state +#[derive(Debug, Default)] +struct ChunkState { + text: Option, + tool_calls: Option>, + pending_tool_calls: Option>, + error: Option, +} + +/// Chunk of the LLM response stored in the Redis stream. +#[derive(Debug, Serialize)] +#[serde(tag = "type", content = "data", rename_all = "snake_case")] +pub(super) enum RedisStreamChunk { + Start, + Ping, + Text(String), + ToolCall(String), + PendingToolCall(String), + Error(String), + Cancel, + End, +} +impl From for HashMap { + /// Converts a `RedisStreamChunk` into a hash map, suitable for the Redis client. + fn from(chunk: RedisStreamChunk) -> Self { + let value = serde_json::to_value(chunk).unwrap_or_default(); + serde_json::from_value(value).unwrap_or_default() + } +} + +impl LlmStreamWriter { + pub fn new(redis: ExclusiveRedisClient, user_id: &Uuid, session_id: &Uuid) -> Self { + LlmStreamWriter { + redis, + key: get_chat_stream_key(user_id, session_id), + current_chunk: ChunkState::default(), + complete_text: None, + tool_calls: None, + errors: None, + usage: None, + } + } + + /// Create the Redis stream and write a `start` entry. + pub async fn start(&self) -> FredResult<()> { + let entry: HashMap = RedisStreamChunk::Start.into(); + let pipeline = self.redis.pipeline(); + let _: () = pipeline.xadd(&self.key, false, None, "*", entry).await?; + let _: () = pipeline.expire(&self.key, STREAM_EXPIRE, None).await?; + pipeline.all().await + } + + /// Add an `end` event to notify clients that the stream has ended, and then + /// delete the stream from Redis. + pub async fn end(&self) -> FredResult<()> { + let entry: HashMap = RedisStreamChunk::End.into(); + let pipeline = self.redis.pipeline(); + let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; + let _: () = pipeline.del(&self.key).await?; + pipeline.all().await + } + + /// Process the incoming stream from the LLM provider, intermittently flushing + /// chunks to a Redis stream, and return the final accumulated response. + pub async fn process( + &mut self, + mut stream: LlmStream, + ) -> ( + Option, + Option>, + Option, + Option>, + bool, + ) { + let ping_handle = self.start_ping_task(); + + let mut last_flush_time = Instant::now(); + let mut cancelled = false; + loop { + match tokio::time::timeout(LLM_TIMEOUT, stream.next()).await { + Ok(Some(Ok(chunk))) => match chunk { + LlmStreamChunk::Text(text) => self.process_text(&text), + LlmStreamChunk::ToolCalls(tool_calls) => self.process_tool_calls(tool_calls), + LlmStreamChunk::PendingToolCall(pending_tool_call) => { + self.process_pending_tool_call(pending_tool_call) + } + LlmStreamChunk::Usage(usage) => self.process_usage(usage), + }, + Ok(Some(Err(err))) => self.process_error(err), + Ok(None) => { + // stream ended + self.flush_chunk().await.ok(); + break; + } + Err(_) => { + // timed out waiting for provider response + self.process_error(LlmStreamError::StreamTimeout); + self.flush_chunk().await.ok(); + break; + } + } + + if self.should_flush(&last_flush_time) { + if let Err(err) = self.flush_chunk().await { + if matches!(err, LlmStreamError::StreamCancelled) { + self.errors.get_or_insert_default().push(err); + cancelled = true; + break; + } + self.process_error(err); + } + last_flush_time = Instant::now(); + } + } + ping_handle.abort(); + + let complete_text = self.complete_text.take(); + let tool_calls = self.tool_calls.take(); + let usage = self.usage.take(); + let errors = self.errors.take().map(|e| { + e.into_iter() + .map(|e| e.to_string()) + .collect::>() + }); + (complete_text, tool_calls, usage, errors, cancelled) + } + + fn process_text(&mut self, text: &str) { + self.current_chunk + .text + .get_or_insert_with(|| String::with_capacity(MAX_CHUNK_SIZE)) + .push_str(text); + self.complete_text + .get_or_insert_with(|| String::with_capacity(1024)) + .push_str(text); + } + + fn process_tool_calls(&mut self, tool_calls: Vec) { + self.current_chunk + .tool_calls + .get_or_insert_default() + .extend(tool_calls.clone()); + self.tool_calls.get_or_insert_default().extend(tool_calls); + } + + fn process_pending_tool_call(&mut self, tool_call: LlmPendingToolCall) { + let current_chunk = self + .current_chunk + .pending_tool_calls + .get_or_insert_default(); + if !current_chunk.iter().any(|tc| tc.index == tool_call.index) { + current_chunk.push(tool_call); + } + } + + fn process_usage(&mut self, usage_chunk: LlmUsage) { + let usage = self.usage.get_or_insert_default(); + if let Some(input_tokens) = usage_chunk.input_tokens { + usage.input_tokens = Some(input_tokens); + } + if let Some(output_tokens) = usage_chunk.output_tokens { + usage.output_tokens = Some(output_tokens); + } + if let Some(cost) = usage_chunk.cost { + usage.cost = Some(cost); + } + } + + fn process_error(&mut self, err: LlmStreamError) { + self.current_chunk.error = Some(err.to_string()); + self.errors.get_or_insert_default().push(err); + } + + fn should_flush(&self, last_flush_time: &Instant) -> bool { + if self.current_chunk.tool_calls.is_some() || self.current_chunk.error.is_some() { + return true; + } + let text = self.current_chunk.text.as_ref(); + last_flush_time.elapsed() > FLUSH_INTERVAL || text.is_some_and(|t| t.len() > MAX_CHUNK_SIZE) + } + + /// Flushes the current chunk to the Redis stream. Returns a `LlmStreamError::StreamCancelled` error + /// if the stream has been deleted or cancelled. + async fn flush_chunk(&mut self) -> Result<(), LlmStreamError> { + let chunk_state = std::mem::take(&mut self.current_chunk); + + let mut chunks: Vec = Vec::with_capacity(2); + if let Some(text) = chunk_state.text { + chunks.push(RedisStreamChunk::Text(text)); + } + if let Some(tool_calls) = chunk_state.tool_calls { + chunks.extend(tool_calls.into_iter().map(|tc| { + RedisStreamChunk::ToolCall(serde_json::to_string(&tc).unwrap_or_default()) + })); + } + if let Some(pending_tool_calls) = chunk_state.pending_tool_calls { + chunks.extend(pending_tool_calls.into_iter().map(|tc| { + RedisStreamChunk::PendingToolCall(serde_json::to_string(&tc).unwrap_or_default()) + })); + } + if let Some(error) = chunk_state.error { + chunks.push(RedisStreamChunk::Error(error)); + } + if chunks.is_empty() { + return Ok(()); + } + + let entries = chunks.into_iter().map(|chunk| chunk.into()).collect(); + self.add_to_redis_stream(entries).await + } + + /// Adds new entries to the Redis stream. Returns a `LlmStreamError::StreamCancelled` error if the + /// stream has been deleted or cancelled. + async fn add_to_redis_stream( + &self, + entries: Vec>, + ) -> Result<(), LlmStreamError> { + let pipeline = self.redis.pipeline(); + for entry in entries { + let _: () = pipeline + .xadd(&self.key, true, ("MAXLEN", "~", 500), "*", entry) + .await?; + } + let res: Vec = pipeline.all().await?; + + // Check for `nil` responses indicating the stream has been deleted/cancelled + if res.iter().any(|r| r.is_null()) { + Err(LlmStreamError::StreamCancelled) + } else { + Ok(()) + } + } + + /// Start task that pings the Redis stream every `PING_INTERVAL` seconds and extends the expiration time + fn start_ping_task(&self) -> tokio::task::JoinHandle<()> { + let redis = self.redis.clone(); + let key = self.key.to_owned(); + let ping_handle = tokio::spawn(async move { + let mut interval = tokio::time::interval(PING_INTERVAL); + loop { + interval.tick().await; + let entry: HashMap = RedisStreamChunk::Ping.into(); + let pipeline = redis.pipeline(); + let _: FredResult<()> = pipeline.xadd(&key, true, None, "*", entry).await; + let _: FredResult<()> = pipeline.expire(&key, STREAM_EXPIRE, None).await; + let res: FredResult> = pipeline.all().await; + if res.is_err() || res.is_ok_and(|r| r.iter().any(|v| v.is_null())) { + break; + } + } + }); + ping_handle + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + provider::{lorem::LoremProvider, LlmApiProvider, LlmProviderOptions}, + redis::{ExclusiveClientManager, ExclusiveClientPool}, + stream::{cancel_current_chat_stream, check_chat_stream_exists}, + }; + use fred::prelude::{Builder, ClientLike, Config}; + use std::time::Duration; + + async fn setup_redis_pool() -> ExclusiveClientPool { + let config = + Config::from_url("redis://127.0.0.1:6379").unwrap_or_else(|_| Config::default()); + let pool = Builder::from_config(config) + .build_pool(1) + .expect("Failed to build Redis pool"); + pool.init().await.expect("Failed to connect to Redis"); + + let manager = ExclusiveClientManager::new(pool.clone()); + let deadpool: ExclusiveClientPool = deadpool::managed::Pool::builder(manager) + .max_size(3) + .build() + .unwrap(); + + deadpool + } + + async fn create_test_writer( + redis: &ExclusiveClientPool, + user_id: &Uuid, + session_id: &Uuid, + ) -> LlmStreamWriter { + let client = redis.get().await.expect("Failed to get Redis client"); + LlmStreamWriter::new(ExclusiveRedisClient(client), user_id, session_id) + } + + #[tokio::test] + async fn test_stream_writer_basic_functionality() { + let redis = setup_redis_pool().await; + let client = redis.get().await.unwrap(); + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; + + // Create stream + assert!(writer.start().await.is_ok()); + assert!(check_chat_stream_exists(&client, &user_id, &session_id) + .await + .unwrap()); + + // Create Lorem provider and get stream + let lorem = LoremProvider::new(); + let stream = lorem + .chat_stream(vec![], None, &LlmProviderOptions::default()) + .await + .expect("Failed to create lorem stream"); + + // Process the stream + let (text, tool_calls, usage, errors, cancelled) = writer.process(stream).await; + + // Verify results + assert!(text.is_some()); + let text = text.unwrap(); + assert!(!text.is_empty()); + assert!(text.contains("Lorem ipsum")); + assert!(text.contains("dolor sit")); + + assert!(tool_calls.is_none()); + assert!(usage.is_none()); + assert!(errors.is_some()); // Lorem provider generates some test errors + assert!(!cancelled); + + // End stream + assert!(writer.end().await.is_ok()); + + // Stream should be deleted after end + assert!(!check_chat_stream_exists(&client, &user_id, &session_id) + .await + .unwrap()); + } + + #[tokio::test] + async fn test_stream_writer_batching() { + let redis = setup_redis_pool().await; + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; + + assert!(writer.start().await.is_ok()); + + // Create a custom stream with small chunks to test batching + let chunks = vec![ + "Hello", " ", "world", "!", " ", "This", " ", "is", " ", "a", " ", "test", + ]; + let chunk_stream = tokio_stream::iter( + chunks + .into_iter() + .map(|text| Ok(LlmStreamChunk::Text(text.into()))), + ); + + let stream: LlmStream = Box::pin(chunk_stream); + let (text, _, _, _, cancelled) = writer.process(stream).await; + + assert!(text.is_some()); + let text = text.unwrap(); + assert_eq!(text, "Hello world! This is a test"); + assert!(!cancelled); + + writer.end().await.ok(); + } + + #[tokio::test] + async fn test_stream_writer_error_handling() { + let redis = setup_redis_pool().await; + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; + + assert!(writer.start().await.is_ok()); + + // Create a stream that produces an error + let error_stream = tokio_stream::iter(vec![ + Ok(LlmStreamChunk::Text("Hello".to_string())), + Err(LlmStreamError::ProviderError("Test error".into())), + Ok(LlmStreamChunk::Text(" World".to_string())), + ]); + + let stream: LlmStream = Box::pin(error_stream); + let (text, _, _, errors, cancelled) = writer.process(stream).await; + + assert!(text.is_some()); + let text = text.unwrap(); + assert_eq!(text, "Hello World"); + + assert!(errors.is_some()); + let errors = errors.unwrap(); + assert!(!errors.is_empty()); + assert!(errors.iter().any(|e| e.contains("Test error"))); + + assert!(!cancelled); + + writer.end().await.ok(); + } + + #[tokio::test] + async fn test_stream_writer_timeout() { + let redis = setup_redis_pool().await; + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; + + assert!(writer.start().await.is_ok()); + + // Create a stream that hangs (never yields anything) + let hanging_stream = tokio_stream::pending::>(); + + let stream: LlmStream = Box::pin(hanging_stream); + + // This should timeout due to LLM_TIMEOUT + let start = std::time::Instant::now(); + let (text, _, _, errors, cancelled) = writer.process(stream).await; + let elapsed = start.elapsed(); + + // Should complete in roughly LLM_TIMEOUT duration + assert!(elapsed >= Duration::from_secs(59)); // Allow some margin + assert!(elapsed < Duration::from_secs(65)); + + assert!(text.is_none()); + assert!(errors.is_some()); + let errors = errors.unwrap(); + assert!(errors.iter().any(|e| e.contains("Timeout"))); + assert!(!cancelled); // Timeout is not considered a cancellation + + writer.end().await.ok(); + } + + #[tokio::test] + async fn test_stream_writer_cancel() { + let redis = setup_redis_pool().await; + let client = redis.get().await.unwrap(); + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let writer = create_test_writer(&redis, &user_id, &session_id).await; + + assert!(writer.start().await.is_ok()); + assert!(check_chat_stream_exists(&client, &user_id, &session_id) + .await + .unwrap()); + + // Cancel the stream + assert!(cancel_current_chat_stream(&client, &user_id, &session_id) + .await + .is_ok()); + + // Stream should be deleted after cancel + assert!(!check_chat_stream_exists(&client, &user_id, &session_id) + .await + .unwrap()); + } + + #[tokio::test] + async fn test_stream_writer_usage_tracking() { + let redis = setup_redis_pool().await; + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; + + assert!(writer.start().await.is_ok()); + + // Create a stream with usage information + let usage_stream = tokio_stream::iter(vec![ + Ok(LlmStreamChunk::Text("Hello".into())), + Ok(LlmStreamChunk::Usage(LlmUsage { + input_tokens: Some(10), + output_tokens: Some(5), + cost: Some(0.001), + })), + Ok(LlmStreamChunk::Text(" World".into())), + Ok(LlmStreamChunk::Usage(LlmUsage { + input_tokens: None, // Should not override + output_tokens: Some(7), // Should update + cost: Some(0.002), // Should update + })), + ]); + + let stream: LlmStream = Box::pin(usage_stream); + let (text, _, usage, _, cancelled) = writer.process(stream).await; + + assert!(text.is_some()); + assert_eq!(text.unwrap(), "Hello World"); + + assert!(usage.is_some()); + let usage = usage.unwrap(); + assert_eq!(usage.input_tokens, Some(10)); + assert_eq!(usage.output_tokens, Some(7)); + assert_eq!(usage.cost, Some(0.002)); + + assert!(!cancelled); + + writer.end().await.ok(); + } + + #[tokio::test] + async fn test_redis_stream_entries() { + let redis = setup_redis_pool().await; + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; + let key = writer.key.clone(); + + assert!(writer.start().await.is_ok()); + + // Verify start event was written + let entries: Vec<(String, HashMap)> = redis + .get() + .await + .expect("Failed to get Redis connection") + .xrange(&key, "-", "+", None) + .await + .expect("Failed to read stream"); + + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].1.get("type"), Some(&"start".to_string())); + + // Create a simple stream + let simple_stream = tokio_stream::iter(vec![Ok(LlmStreamChunk::Text("Test chunk".into()))]); + let stream: LlmStream = Box::pin(simple_stream); + writer.process(stream).await; + writer.flush_chunk().await.ok(); + + // Should have start + text entries (ping task may add more) + let final_entries: Vec<(String, HashMap)> = redis + .get() + .await + .expect("Failed to get Redis connection") + .xrange(&key, "-", "+", None) + .await + .expect("Failed to read stream"); + + assert!(final_entries.len() >= 2); + + // Check that we have at least a text entry + let has_text = final_entries + .iter() + .any(|(_, data)| data.get("type") == Some(&"text".to_string())); + assert!(has_text); + + writer.end().await.ok(); + } +} diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs new file mode 100644 index 0000000..c0eb870 --- /dev/null +++ b/server/src/stream/reader.rs @@ -0,0 +1,130 @@ +use std::collections::HashMap; + +use fred::prelude::StreamsInterface; +use rocket::response::stream::Event; +use tokio::sync::mpsc; +use uuid::Uuid; + +use crate::{provider::LlmError, redis::ExclusiveRedisClient, stream::get_chat_stream_key}; + +/// Timeout in milliseconds for the blocking `xread` command. +const XREAD_BLOCK_TIMEOUT: u64 = 5_000; // 5 seconds + +/// Utility for reading SSE events from a Redis stream. +pub struct SseStreamReader { + redis: ExclusiveRedisClient, +} + +impl SseStreamReader { + pub fn new(redis: ExclusiveRedisClient) -> Self { + Self { redis } + } + + /// Retrieve the previous events from the given Redis stream. + /// Returns a tuple containing the previous events, the last event ID, and a boolean + /// indicating if the stream has already ended. + pub async fn get_prev_events( + &self, + user_id: &Uuid, + session_id: &Uuid, + start_event_id: Option<&str>, + ) -> Result<(Vec, String, bool), LlmError> { + let key = get_chat_stream_key(user_id, session_id); + let start_event_id = start_event_id.unwrap_or("0-0"); + let (_, prev_events): (String, Vec<(String, HashMap)>) = self + .redis + .xread::>, _, _>(None, None, &key, start_event_id) + .await? + .and_then(|mut streams| streams.pop()) // should only be 1 stream since we're sending 1 key in the command + .ok_or(LlmError::StreamNotFound)?; + let (last_event_id, is_end) = prev_events + .last() + .map(|(id, data)| (id.to_owned(), data.get("type").is_some_and(|t| t == "end"))) + .unwrap_or_else(|| (start_event_id.into(), false)); + let sse_events = prev_events + .into_iter() + .map(convert_redis_event_to_sse) + .collect::>(); + + Ok((sse_events, last_event_id, is_end)) + } + + /// Stream the events from the given Redis stream using a blocking `xread` command. + pub async fn stream( + &self, + user_id: &Uuid, + session_id: &Uuid, + last_event_id: &str, + tx: &mpsc::Sender, + ) { + let key = get_chat_stream_key(user_id, session_id); + let mut last_event_id = last_event_id.to_owned(); + loop { + match self.get_next_event(&key, &mut last_event_id, tx).await { + Ok((id, data, is_end)) => { + let event = convert_redis_event_to_sse((id, data)); + if let Err(_) = tx.send(event).await { + break; // client disconnected + } + if is_end { + break; // reached end of stream + } + } + Err(err) => { + let event = Event::data(format!("Error: {}", err)).event("error"); + tx.send(event).await.ok(); + break; + } + } + } + } + + /// Get the next event from the given Redis stream using a blocking `xread` command. + /// - Updates the last event ID + /// - Cancels waiting for the next event if the client disconnects + /// - Returns the event ID, data, and a `bool` indicating whether it's an ending event + async fn get_next_event( + &self, + key: &str, + last_event_id: &mut String, + tx: &mpsc::Sender, + ) -> Result<(String, HashMap, bool), LlmError> { + let (_, mut events): (String, Vec<(String, HashMap)>) = tokio::select! { + res = self.redis.xread::>, _, _>(Some(1), Some(XREAD_BLOCK_TIMEOUT), key, &*last_event_id) => { + match res?.as_mut().and_then(|streams| streams.pop()) { + Some(stream) => stream, + None => return Err(LlmError::StreamNotFound), + } + }, + _ = tx.closed() => return Err(LlmError::ClientDisconnected) + }; + match events.pop() { + Some((id, data)) => { + *last_event_id = id.clone(); + let is_end = data + .get("type") + .is_some_and(|t| t == "end" || t == "cancel"); + Ok((id, data, is_end)) + } + None => Err(LlmError::NoStreamEvent), + } + } +} + +/// Convert a Redis stream event into an SSE event. Expects the event hash map to contain +/// a "type" and "data" field (e.g. serialized using the appropriate serde tag and content). +fn convert_redis_event_to_sse((id, event): (String, HashMap)) -> Event { + let mut r#type: Option = None; + let mut data: Option = None; + for (key, value) in event { + match key.as_str() { + "type" => r#type = Some(value), + "data" => data = Some(format!(" {value}")), // SSE spec: add space before data + _ => {} + } + } + + Event::data(data.unwrap_or_default()) + .event(r#type.unwrap_or_else(|| "unknown".into())) + .id(id) +} diff --git a/server/src/tools.rs b/server/src/tools.rs index e2d1d11..7d81cab 100644 --- a/server/src/tools.rs +++ b/server/src/tools.rs @@ -9,7 +9,11 @@ pub use { system::{ChatRsSystemToolConfig, SystemToolInput}, }; -use schemars::JsonSchema; +use { + crate::{db::services::ToolDbService, errors::ApiError, provider::LlmTool}, + schemars::JsonSchema, + uuid::Uuid, +}; /// User configuration of tools when sending a chat message #[derive(Debug, Default, PartialEq, JsonSchema, serde::Serialize, serde::Deserialize)] @@ -17,3 +21,28 @@ pub struct SendChatToolInput { pub system: Option, pub external_apis: Option>, } + +/// Get all tools from the user's input in LLM generic format +pub async fn get_llm_tools_from_input( + user_id: &Uuid, + input: &SendChatToolInput, + tool_db_service: &mut ToolDbService<'_>, +) -> Result, ApiError> { + let mut llm_tools = Vec::with_capacity(5); + if let Some(ref system_tool_input) = input.system { + let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; + let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; + llm_tools.extend(system_llm_tools); + } + if let Some(ref external_apis_input) = input.external_apis { + let external_api_tools = tool_db_service + .find_external_api_tools_by_user(&user_id) + .await?; + for tool_input in external_apis_input { + let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; + llm_tools.extend(api_llm_tools); + } + } + + Ok(llm_tools) +} diff --git a/server/src/tools/system/system_info.rs b/server/src/tools/system/system_info.rs index 3cb5462..370c387 100644 --- a/server/src/tools/system/system_info.rs +++ b/server/src/tools/system/system_info.rs @@ -25,6 +25,7 @@ const SERVER_URL_DESC: &str = "Get the URL of the server that this chat applicat /// Tool to get system information. #[derive(Debug, JsonSchema)] +#[serde(deny_unknown_fields)] pub struct SystemInfo {} impl SystemInfo { pub fn new() -> Self { diff --git a/server/src/utils.rs b/server/src/utils.rs index 32b9c14..bd85450 100644 --- a/server/src/utils.rs +++ b/server/src/utils.rs @@ -3,11 +3,9 @@ mod full_text_search; mod generate_title; mod json_logging; mod sender_with_logging; -mod stored_stream; pub use encryption::*; pub use full_text_search::*; pub use generate_title::*; pub use json_logging::*; pub use sender_with_logging::*; -pub use stored_stream::*; diff --git a/server/src/utils/generate_title.rs b/server/src/utils/generate_title.rs index f61589a..a517f1f 100644 --- a/server/src/utils/generate_title.rs +++ b/server/src/utils/generate_title.rs @@ -1,87 +1,68 @@ +use std::ops::Deref; + use uuid::Uuid; use crate::{ - db::{ - models::{ChatRsProviderType, UpdateChatRsSession}, - services::ChatDbService, - DbConnection, DbPool, - }, - provider::{build_llm_provider_api, LlmApiProviderSharedOptions, DEFAULT_TEMPERATURE}, + db::{models::UpdateChatRsSession, services::ChatDbService, DbConnection, DbPool}, + errors::ApiError, + provider::{LlmApiProvider, LlmProviderOptions, DEFAULT_TEMPERATURE}, }; -const TITLE_PROMPT: &str = "This is the first message sent by a human in a session with an AI chatbot. Please generate a short title for the session (max 6 words) in plain text"; +const TITLE_TOKENS: u32 = 20; +const TITLE_PROMPT: &str = + "This is the first message sent by a human in a chat session with an AI chatbot. \ + Please generate a short title for the session (3-7 words) in plain text \ + (no quotes or prefixes)"; -/// Spawns a task to generate a title for the chat session +/// Spawn a task to generate a title for the chat session pub fn generate_title( user_id: &Uuid, session_id: &Uuid, user_message: &str, - provider_type: ChatRsProviderType, + provider: &Box, model: &str, - base_url: Option<&str>, - api_key: Option, - http_client: &reqwest::Client, - redis: &fred::prelude::Client, pool: &DbPool, ) { let user_id = user_id.to_owned(); let session_id = session_id.to_owned(); let user_message = user_message.to_owned(); + let provider = dyn_clone::clone_box(provider.deref()); let model = model.to_owned(); - let base_url = base_url.map(|url| url.to_owned()); - let http_client = http_client.clone(); - let redis = redis.clone(); let pool = pool.clone(); tokio::spawn(async move { - let Ok(conn) = pool.get().await else { - rocket::error!("Couldn't get database connection"); - return; - }; - let mut db = DbConnection(conn); - let Ok(provider) = build_llm_provider_api( - &provider_type, - base_url.as_deref(), - api_key.as_deref(), - &http_client, - &redis, - ) else { - rocket::warn!("Error creating provider for chat {}", session_id); - return; - }; - - let provider_options = LlmApiProviderSharedOptions { - model, - temperature: Some(DEFAULT_TEMPERATURE), - max_tokens: Some(20), - }; - let provider_response = provider - .prompt( - &format!("{}: \"{}\"", TITLE_PROMPT, user_message), - &provider_options, - ) - .await; - - match provider_response { - Ok(title) => { - rocket::info!("Generated title for chat {}", session_id); - if let Err(e) = ChatDbService::new(&mut db) - .update_session( - &user_id, - &session_id, - UpdateChatRsSession { - title: Some(title.trim()), - ..Default::default() - }, - ) - .await - { - rocket::warn!("Error saving title for chat {}: {}", session_id, e); - }; - } - Err(e) => { - rocket::warn!("Error generating title for chat {}: {}", session_id, e); - } + if let Err(err) = generate(user_id, session_id, user_message, provider, model, pool).await { + rocket::warn!("Failed to generate title: {}", err); } }); } + +async fn generate( + user_id: Uuid, + session_id: Uuid, + user_message: String, + provider: Box, + model: String, + pool: DbPool, +) -> Result<(), ApiError> { + let provider_options = LlmProviderOptions { + model, + temperature: Some(DEFAULT_TEMPERATURE), + max_tokens: Some(TITLE_TOKENS), + }; + let message = format!("{}: \"{}\"", TITLE_PROMPT, user_message); + let title = provider.prompt(&message, &provider_options).await?; + + let mut db = DbConnection(pool.get().await?); + ChatDbService::new(&mut db) + .update_session( + &user_id, + &session_id, + UpdateChatRsSession { + title: Some(title.trim()), + ..Default::default() + }, + ) + .await?; + Ok(()) +} diff --git a/server/src/utils/stored_stream.rs b/server/src/utils/stored_stream.rs deleted file mode 100644 index 35f54a1..0000000 --- a/server/src/utils/stored_stream.rs +++ /dev/null @@ -1,221 +0,0 @@ -use std::{ - pin::Pin, - task::{Context, Poll}, - time::{Duration, Instant}, -}; - -use fred::{ - prelude::{Client, KeysInterface}, - types::Expiration, -}; -use rocket::futures::Stream; -use uuid::Uuid; - -use crate::{ - db::{ - models::{ - AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsToolCall, NewChatRsMessage, - }, - services::ChatDbService, - DbConnection, DbPool, - }, - provider::{LlmApiProviderSharedOptions, LlmUsage}, -}; - -/// A wrapper around the chat assistant stream that intermittently caches output in Redis, and -/// saves the assistant's response to the database at the end of the stream. -pub struct StoredChatRsStream< - S: Stream>, -> { - inner: Pin>, - provider_id: i32, - provider_options: Option, - redis_client: Client, - db_pool: DbPool, - session_id: Uuid, - buffer: Vec, - tool_calls: Option>, - input_tokens: u32, - output_tokens: u32, - cost: Option, - last_cache_time: Instant, -} - -/// The prefix for the cache key used for streaming chat responses. -pub const CHAT_CACHE_KEY_PREFIX: &str = "chat_session:"; -const CHAT_CACHE_INTERVAL: Duration = Duration::from_secs(1); // cache the response every second - -impl StoredChatRsStream -where - S: Stream>, -{ - pub fn new( - stream: S, - provider_id: i32, - provider_options: LlmApiProviderSharedOptions, - db_pool: DbPool, - redis_client: Client, - session_id: Option, - ) -> Self { - Self { - inner: Box::pin(stream), - provider_id, - provider_options: Some(provider_options), - db_pool, - redis_client, - session_id: session_id.unwrap_or_else(|| Uuid::new_v4()), - buffer: Vec::new(), - tool_calls: None, - input_tokens: 0, - output_tokens: 0, - cost: None, - last_cache_time: Instant::now(), - } - } - - pub fn session_id(&self) -> &Uuid { - &self.session_id - } - - fn save_response(&mut self, interrupted: Option) { - let redis_client = self.redis_client.clone(); - let pool = self.db_pool.clone(); - let session_id = self.session_id.clone(); - let provider_id = self.provider_id.clone(); - let provider_options = self.provider_options.take(); - let content = self.buffer.join(""); - let tool_calls = self.tool_calls.take(); - let usage = Some(LlmUsage { - input_tokens: Some(self.input_tokens), - output_tokens: Some(self.output_tokens), - cost: self.cost, - }); - self.buffer.clear(); - - tokio::spawn(async move { - let Ok(db) = pool.get().await else { - rocket::error!("Couldn't get connection while saving chat response"); - return; - }; - if let Err(e) = ChatDbService::new(&mut DbConnection(db)) - .save_message(NewChatRsMessage { - role: ChatRsMessageRole::Assistant, - content: &content, - session_id: &session_id, - meta: ChatRsMessageMeta { - assistant: Some(AssistantMeta { - provider_id, - provider_options, - partial: interrupted, - usage, - tool_calls, - }), - ..Default::default() - }, - }) - .await - { - rocket::error!("Failed saving chat response, session {}: {}", session_id, e); - } else { - rocket::info!("Saved chat response, session {}", session_id); - } - - let key = format!("{}{}", CHAT_CACHE_KEY_PREFIX, session_id); - let _ = redis_client.del::<(), _>(&key).await; - }); - } - - fn should_cache(&self) -> bool { - self.last_cache_time.elapsed() >= CHAT_CACHE_INTERVAL - } -} - -impl Stream for StoredChatRsStream -where - S: Stream>, -{ - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.inner.as_mut().poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { - // Add text to buffer - if let Some(text) = &chunk.text { - self.buffer.push(text.clone()); - } - - // Record tool calls - if let Some(tool_calls) = chunk.tool_calls { - self.tool_calls.get_or_insert_default().extend(tool_calls); - } - - // Record usage - if let Some(usage) = chunk.usage { - if let Some(input_tokens) = usage.input_tokens { - self.input_tokens = input_tokens; - } - if let Some(output_tokens) = usage.output_tokens { - self.output_tokens = output_tokens; - } - if let Some(cost) = usage.cost { - self.cost = Some(cost); - } - } - - // Check if we should cache - if self.should_cache() { - let redis_client = self.redis_client.clone(); - let session_id = self.session_id.clone(); - let content = self.buffer.join(""); - - // Spawn async task to cache - tokio::spawn(async move { - let key = format!("{}{}", CHAT_CACHE_KEY_PREFIX, session_id); - rocket::debug!("Caching chat session {}", session_id); - if let Err(e) = redis_client - .set::<(), _, _>( - &key, - &content, - Some(Expiration::EX(3600)), - None, - false, - ) - .await - { - rocket::error!("Redis cache error: {}", e); - } - }); - - self.last_cache_time = Instant::now(); - } - - if let Some(text) = chunk.text { - Poll::Ready(Some(Ok(text))) - } else { - self.poll_next(cx) - } - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - // Stream ended, flush final buffer - if !self.buffer.is_empty() || self.tool_calls.is_some() { - self.save_response(None); - } - Poll::Ready(None) - } - Poll::Pending => Poll::Pending, - } - } -} - -impl Drop for StoredChatRsStream -where - S: Stream>, -{ - /// Stream was interrupted. Save response and mark as interrupted - fn drop(&mut self) { - if !self.buffer.is_empty() || self.tool_calls.is_some() { - self.save_response(Some(true)); - } - } -} diff --git a/web/package.json b/web/package.json index 9c710f0..bfffb6e 100644 --- a/web/package.json +++ b/web/package.json @@ -35,6 +35,7 @@ "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "cmdk": "^1.1.1", + "eventsource-parser": "^3.0.5", "highlight.svelte": "^0.1.3", "lowlight": "^3.3.0", "lucide-react": "^0.476.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index ff7e221..0e4906b 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -65,6 +65,9 @@ importers: cmdk: specifier: ^1.1.1 version: 1.1.1(@types/react-dom@19.1.6(@types/react@19.1.8))(@types/react@19.1.8)(react-dom@19.1.0(react@19.1.0))(react@19.1.0) + eventsource-parser: + specifier: ^3.0.5 + version: 3.0.5 highlight.svelte: specifier: ^0.1.3 version: 0.1.3 @@ -1613,6 +1616,10 @@ packages: estree-walker@3.0.3: resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} + eventsource-parser@3.0.5: + resolution: {integrity: sha512-bSRG85ZrMdmWtm7qkF9He9TNRzc/Bm99gEJMaQoHJ9E6Kv9QBbsldh2oMj7iXmYNEAVvNgvv5vPorG6W+XtBhQ==} + engines: {node: '>=20.0.0'} + expect-type@1.2.1: resolution: {integrity: sha512-/kP8CAwxzLVEeFrMm4kMmy4CCDlpipyA7MYLVrdJIkV0fYF0UaigQHRsxHiuY/GEea+bh4KSv3TIlgr+2UL6bw==} engines: {node: '>=12.0.0'} @@ -3965,6 +3972,8 @@ snapshots: dependencies: '@types/estree': 1.0.8 + eventsource-parser@3.0.5: {} + expect-type@1.2.1: {} extend@3.0.2: {} diff --git a/web/src/components/ProviderManager.tsx b/web/src/components/ProviderManager.tsx index 85d2139..5e96f7c 100644 --- a/web/src/components/ProviderManager.tsx +++ b/web/src/components/ProviderManager.tsx @@ -1,5 +1,5 @@ import { Bot, Plus, Trash2 } from "lucide-react"; -import { type FormEventHandler, useState } from "react"; +import { type FormEventHandler, useId, useState } from "react"; import { AlertDialog, @@ -39,13 +39,14 @@ import { import type { components } from "@/lib/api/types"; import { cn } from "@/lib/utils"; -type AIProvider = "anthropic" | "openai" | "openrouter" | "lorem"; +type AIProvider = "anthropic" | "openai" | "openrouter" | "ollama" | "lorem"; interface ProviderInfo { name: string; description: string; apiType: components["schemas"]["ChatRsProvider"]["provider_type"]; baseUrl?: string; + customBaseUrl?: boolean; keyFormat?: string; color: string; defaultModel: string; @@ -79,6 +80,16 @@ const PROVIDERS: Record = { color: "bg-blue-100 dark:bg-blue-900 border-blue-300 dark:border-blue-700", defaultModel: "openai/gpt-4o-mini", }, + ollama: { + name: "Ollama", + description: "Run LLMs locally with Ollama", + apiType: "ollama", + baseUrl: "http://localhost:11434", + customBaseUrl: true, + color: + "bg-indigo-100 dark:bg-indigo-900 border-indigo-300 dark:border-indigo-700", + defaultModel: "llama3.2", + }, lorem: { name: "Lorem Ipsum", description: "Generate placeholder text for testing", @@ -103,14 +114,22 @@ export function ProviderManager({ const [selectedProvider, setSelectedProvider] = useState( null, ); + const [name, setName] = useState(""); const [newApiKey, setNewApiKey] = useState(""); + const [newBaseUrl, setNewBaseUrl] = useState(""); + + const nameId = useId(); + const apiKeyId = useId(); + const baseUrlId = useId(); const handleCreateKey: FormEventHandler = (event) => { event.preventDefault(); if ( !selectedProvider || - (PROVIDERS[selectedProvider].apiType !== "lorem" && !newApiKey.trim()) + (PROVIDERS[selectedProvider].apiType !== "lorem" && + PROVIDERS[selectedProvider].apiType !== "ollama" && + !newApiKey.trim()) ) return; @@ -118,7 +137,7 @@ export function ProviderManager({ { name, type: PROVIDERS[selectedProvider].apiType, - base_url: PROVIDERS[selectedProvider].baseUrl, + base_url: newBaseUrl || PROVIDERS[selectedProvider].baseUrl, default_model: PROVIDERS[selectedProvider].defaultModel, api_key: newApiKey || null, }, @@ -127,6 +146,7 @@ export function ProviderManager({ setSelectedProvider(null); setIsCreateDialogOpen(false); setNewApiKey(""); + setNewBaseUrl(""); setName(""); }, }, @@ -204,19 +224,21 @@ export function ProviderManager({ {PROVIDERS[provider].description} -
- {PROVIDERS[provider].keyFormat} -
+ {PROVIDERS[provider].keyFormat && ( +
+ {PROVIDERS[provider].keyFormat} +
+ )} ))}
- + {selectedProvider && PROVIDERS[selectedProvider].keyFormat && (
- +
)} + {selectedProvider && + PROVIDERS[selectedProvider].customBaseUrl && ( +
+ + setNewBaseUrl(e.target.value)} + /> +
+ )}
+ )}