From 649728c173dec199a74505bd2ee6dccaf6ba06a4 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Wed, 20 Aug 2025 21:40:57 -0400 Subject: [PATCH 01/46] server: add util for sending LLM stream to Redis stream --- server/Cargo.toml | 5 +- server/src/db/models/tool.rs | 2 +- server/src/provider.rs | 2 +- server/src/utils.rs | 2 + server/src/utils/llm_stream.rs | 194 +++++++++++++++++++++++++++++++++ 5 files changed, 202 insertions(+), 3 deletions(-) create mode 100644 server/src/utils/llm_stream.rs diff --git a/server/Cargo.toml b/server/Cargo.toml index 1ae19b4..a6d39bc 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -34,7 +34,10 @@ diesel_as_jsonb = "1.0.1" diesel_async_migrations = "0.15.0" dotenvy = "0.15.7" enum-iterator = "2.1.0" -fred = { version = "10.1.0", default-features = false, features = ["i-keys"] } +fred = { version = "10.1.0", default-features = false, features = [ + "i-keys", + "i-streams", +] } hex = "0.4.3" jsonschema = { version = "0.30.0", default-features = false } rand = "0.9.1" diff --git a/server/src/db/models/tool.rs b/server/src/db/models/tool.rs index 19f899e..39b2d29 100644 --- a/server/src/db/models/tool.rs +++ b/server/src/db/models/tool.rs @@ -55,7 +55,7 @@ pub struct NewChatRsExternalApiTool<'r> { } /// A tool call requested by the provider -#[derive(Debug, JsonSchema, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, JsonSchema, serde::Serialize, serde::Deserialize)] pub struct ChatRsToolCall { /// ID of the tool call pub id: String, diff --git a/server/src/provider.rs b/server/src/provider.rs index 2befe4d..e911ca4 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -53,7 +53,7 @@ pub struct LlmStreamChunk { } /// Usage stats from the LLM provider -#[derive(Debug, JsonSchema, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] pub struct LlmUsage { pub input_tokens: Option, pub output_tokens: Option, diff --git a/server/src/utils.rs b/server/src/utils.rs index 32b9c14..46b6f3c 100644 --- a/server/src/utils.rs +++ b/server/src/utils.rs @@ -2,6 +2,7 @@ mod encryption; mod full_text_search; mod generate_title; mod json_logging; +mod llm_stream; mod sender_with_logging; mod stored_stream; @@ -9,5 +10,6 @@ pub use encryption::*; pub use full_text_search::*; pub use generate_title::*; pub use json_logging::*; +pub use llm_stream::*; pub use sender_with_logging::*; pub use stored_stream::*; diff --git a/server/src/utils/llm_stream.rs b/server/src/utils/llm_stream.rs new file mode 100644 index 0000000..4fd57d6 --- /dev/null +++ b/server/src/utils/llm_stream.rs @@ -0,0 +1,194 @@ +use std::time::{Duration, Instant}; + +use fred::prelude::StreamsInterface; +use rocket::futures::StreamExt; + +use crate::{ + db::models::ChatRsToolCall, + provider::{LlmApiStream, LlmError, LlmUsage}, +}; + +const MAX_CHUNK_SIZE: usize = 1000; +const MAX_FLUSH_TIME: Duration = Duration::from_millis(500); + +/// Utility struct for processing an incoming LLM stream and intermittently +/// flushing the data to a Redis stream. +#[derive(Debug, Default)] +pub struct LlmStreamProcessor { + redis: fred::prelude::Client, + /// The current chunk of data being processed. + current_chunk: RedisStreamChunkData, + /// Accumulated text response from the assistant. + complete_text: Option, + /// Accumulated tool calls from the assistant. + tool_calls: Option>, + /// Accumulated errors during the stream from the assistant. + errors: Option>, + /// Accumulated usage information from the assistant. + usage: Option, +} + +#[derive(Debug)] +enum RedisStreamChunk { + Data(RedisStreamChunkData), + End, +} + +#[derive(Debug, Default)] +struct RedisStreamChunkData { + text: Option, + tool_calls: Option>, + error: Option, +} + +impl LlmStreamProcessor { + pub fn new(redis: &fred::prelude::Client) -> Self { + LlmStreamProcessor { + redis: redis.clone(), + ..Default::default() + } + } + + /// Process the incoming stream from the LLM provider, intermittently + /// flush to Redis stream, and return the accumulated results. + pub async fn process_llm_stream( + mut self, + stream_key: &str, + mut stream: LlmApiStream, + ) -> ( + Option, + Option>, + Option, + Option>, + ) { + let mut last_flush_time = Instant::now(); + + while let Some(chunk) = stream.next().await { + match chunk { + Ok(chunk) => { + if let Some(ref text) = chunk.text { + self.process_text(text); + } + if let Some(tool_calls) = chunk.tool_calls { + self.process_tool_calls(tool_calls); + } + if let Some(usage_chunk) = chunk.usage { + self.process_usage(usage_chunk); + } + } + Err(err) => { + self.current_chunk.error = Some(err.to_string()); + self.errors.get_or_insert_default().push(err); + } + } + + if self.should_flush(&last_flush_time) { + self.flush_and_reset_chunk(&stream_key).await; + last_flush_time = Instant::now(); + } + } + + if let Err(e) = self.mark_end_of_redis_stream(&stream_key).await { + self.errors.get_or_insert_default().push(LlmError::Redis(e)); + }; + + (self.complete_text, self.tool_calls, self.usage, self.errors) + } + + fn process_text(&mut self, text: &str) { + self.current_chunk + .text + .get_or_insert_with(|| String::with_capacity(MAX_CHUNK_SIZE + 200)) + .push_str(text); + self.complete_text + .get_or_insert_with(|| String::with_capacity(2000)) + .push_str(text); + } + + fn process_tool_calls(&mut self, tool_calls: Vec) { + self.current_chunk + .tool_calls + .get_or_insert_default() + .extend(tool_calls.clone()); + self.tool_calls.get_or_insert_default().extend(tool_calls); + } + + fn process_usage(&mut self, usage_chunk: LlmUsage) { + let usage = self.usage.get_or_insert_default(); + if let Some(input_tokens) = usage_chunk.input_tokens { + usage.input_tokens = Some(input_tokens); + } + if let Some(output_tokens) = usage_chunk.output_tokens { + usage.output_tokens = Some(output_tokens); + } + if let Some(cost) = usage_chunk.cost { + usage.cost = Some(cost); + } + } + + fn should_flush(&self, last_flush_time: &Instant) -> bool { + // Flush if there are any tool calls or errors + if self.current_chunk.tool_calls.is_some() || self.current_chunk.error.is_some() { + return true; + } + // Skip flushing if chunk is completely empty + if self.current_chunk.text.is_none() { + return false; + } + // Check for time and size triggers + last_flush_time.elapsed() > MAX_FLUSH_TIME + || self + .current_chunk + .text + .as_ref() + .is_some_and(|t| t.len() > MAX_CHUNK_SIZE) + } + + async fn add_to_redis_stream( + &mut self, + stream_key: &str, + data: Vec<(&str, String)>, + ) -> Result<(), fred::prelude::Error> { + self.redis.xadd(stream_key, false, None, "*", data).await + } + + async fn flush_and_reset_chunk(&mut self, stream_key: &str) { + let chunk = std::mem::take(&mut self.current_chunk); + if let Ok(data) = RedisStreamChunk::Data(chunk).try_into() { + let _ = self.add_to_redis_stream(stream_key, data).await; + } + } + + async fn mark_end_of_redis_stream( + &mut self, + stream_key: &str, + ) -> Result<(), fred::prelude::Error> { + let data = RedisStreamChunk::End.try_into().expect("Should convert"); + self.add_to_redis_stream(stream_key, data).await + } +} + +impl TryFrom for Vec<(&str, String)> { + type Error = serde_json::Error; + + /// Converts a `RedisStreamChunk` into a vector of key-value pairs, suitable for the Redis client. + fn try_from(chunk: RedisStreamChunk) -> Result { + match chunk { + RedisStreamChunk::Data(data) => { + let mut vec = Vec::with_capacity(3); + vec.push(("type", "data".into())); + if let Some(text) = data.text { + vec.push(("text", text)); + } + if let Some(tool_calls) = data.tool_calls { + vec.push(("tool_calls", serde_json::to_string(&tool_calls)?)); + } + if let Some(error) = data.error { + vec.push(("error", error)); + } + Ok(vec) + } + RedisStreamChunk::End => Ok(vec![("type", "end".into())]), + } + } +} From afdf7564854f48891a9c23c8d071ac81acf4a8ac Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 21 Aug 2025 05:38:39 -0400 Subject: [PATCH 02/46] server: add new API for streaming --- server/src/api/chat.rs | 268 ++++++++++++++++++++++++++++++++- server/src/config.rs | 2 +- server/src/provider.rs | 4 + server/src/redis.rs | 2 +- server/src/utils/llm_stream.rs | 144 ++++++++++-------- 5 files changed, 348 insertions(+), 72 deletions(-) diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index 46a24ec..24c8422 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -1,5 +1,6 @@ -use std::{borrow::Cow, pin::Pin}; +use std::{borrow::Cow, collections::HashMap, pin::Pin}; +use fred::prelude::{KeysInterface, StreamsInterface}; use rocket::{ futures::{Stream, StreamExt}, post, @@ -11,6 +12,7 @@ use rocket_okapi::{ okapi::openapi3::OpenApi, openapi, openapi_get_routes_spec, settings::OpenApiSettings, }; use schemars::JsonSchema; +use tokio_stream::wrappers::ReceiverStream; use uuid::Uuid; use crate::{ @@ -18,21 +20,21 @@ use crate::{ auth::ChatRsUserId, db::{ models::{ - ChatRsMessageMeta, ChatRsMessageRole, ChatRsProviderType, ChatRsSessionMeta, - NewChatRsMessage, UpdateChatRsSession, + AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsProviderType, + ChatRsSessionMeta, NewChatRsMessage, UpdateChatRsSession, }, services::{ChatDbService, ProviderDbService, ToolDbService}, DbConnection, DbPool, }, errors::ApiError, - provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmTool}, + provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmError, LlmTool}, redis::RedisClient, tools::SendChatToolInput, - utils::{generate_title, Encryptor, StoredChatRsStream}, + utils::{generate_title, Encryptor, LlmStreamProcessor, StoredChatRsStream}, }; pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { - openapi_get_routes_spec![settings: send_chat_stream] + openapi_get_routes_spec![settings: send_chat_stream, send_chat_stream_v2, connect_to_chat_stream] } #[derive(JsonSchema, serde::Deserialize)] @@ -170,3 +172,257 @@ pub async fn send_chat_stream( .boxed(); Ok(EventStream::from(event_stream)) } + +/// # Start chat stream +/// Send a chat message and start streaming the response +#[openapi(tag = "Chat")] +#[post("//v2", data = "")] +pub async fn send_chat_stream_v2( + user_id: ChatRsUserId, + db_pool: &State, + mut db: DbConnection, + redis: RedisClient, + encryptor: &State, + http_client: &State, + session_id: Uuid, + mut input: Json>, +) -> Result { + // Check that we aren't already streaming a response for this session + let stream_key = format!("user:{}:chat:{}", user_id.0, session_id); + if redis.exists(&stream_key).await? { + return Err(LlmError::AlreadyStreaming)?; + } + + // Check session exists and user is owner, get message history + let (session, mut current_messages) = ChatDbService::new(&mut db) + .get_session_with_messages(&user_id, &session_id) + .await?; + + // Build the chat provider + let (provider, api_key_secret) = ProviderDbService::new(&mut db) + .get_by_id(&user_id, input.provider_id) + .await?; + let provider_type: ChatRsProviderType = provider.provider_type.as_str().try_into()?; + let api_key = api_key_secret + .map(|secret| encryptor.decrypt_string(&secret.ciphertext, &secret.nonce)) + .transpose()?; + let provider_api = build_llm_provider_api( + &provider_type, + provider.base_url.as_deref(), + api_key.as_deref(), + &http_client, + &redis, + )?; + + // Get the user's chosen tools + let mut llm_tools: Option> = None; + let mut tool_db_service = ToolDbService::new(&mut db); + if let Some(system_tool_input) = input.tools.as_ref().and_then(|t| t.system.as_ref()) { + let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; + let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; + llm_tools.get_or_insert_default().extend(system_llm_tools); + } + if let Some(external_apis_input) = input.tools.as_ref().and_then(|t| t.external_apis.as_ref()) { + let external_api_tools = tool_db_service + .find_external_api_tools_by_user(&user_id) + .await?; + for tool_input in external_apis_input { + let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; + llm_tools.get_or_insert_default().extend(api_llm_tools); + } + } + + // Save user message and generate session title if needed + if let Some(user_message) = &input.message { + if current_messages.is_empty() && session.title == DEFAULT_SESSION_TITLE { + generate_title( + &user_id, + &session_id, + &user_message, + provider_type, + &provider.default_model, + provider.base_url.as_deref(), + api_key.clone(), + &http_client, + &redis, + db_pool, + ); + } + + let new_message = ChatDbService::new(&mut db) + .save_message(NewChatRsMessage { + content: user_message, + session_id: &session_id, + role: ChatRsMessageRole::User, + meta: ChatRsMessageMeta::default(), + }) + .await?; + current_messages.push(new_message); + } + + // Update session metadata if needed + if let Some(tool_input) = input.tools.take() { + if session + .meta + .tool_config + .is_none_or(|config| config != tool_input) + { + ChatDbService::new(&mut db) + .update_session( + &user_id, + &session_id, + UpdateChatRsSession { + meta: Some(&ChatRsSessionMeta { + tool_config: Some(tool_input), + }), + ..Default::default() + }, + ) + .await?; + } + } + + // Get the provider's stream response, and spawn a task to stream it to Redis + // and save the response to the database on completion + let stream = provider_api + .chat_stream(current_messages, llm_tools, &input.provider_options) + .await?; + let stream_processor = LlmStreamProcessor::new(&redis); + let provider_id = input.provider_id; + let provider_options = input.provider_options.clone(); + tokio::spawn(async move { + let (content, tool_calls, usage, _) = stream_processor + .process_llm_stream(&stream_key, stream) + .await; + if let Err(e) = ChatDbService::new(&mut db) + .save_message(NewChatRsMessage { + session_id: &session_id, + role: ChatRsMessageRole::Assistant, + content: &content.unwrap_or_default(), + meta: ChatRsMessageMeta { + assistant: Some(AssistantMeta { + provider_id, + provider_options: Some(provider_options), + tool_calls, + usage, + ..Default::default() + }), + ..Default::default() + }, + }) + .await + { + rocket::warn!("Failed to save assistant response: {}", e); + } + + // TODO delete stream in Redis + }); + + Ok("Stream started".into()) +} + +/// # Connect to chat stream +/// Connect to an ongoing chat stream and stream the assistant response +#[openapi(tag = "Chat")] +#[post("//stream")] +pub async fn connect_to_chat_stream( + user_id: ChatRsUserId, + redis: RedisClient, + session_id: Uuid, +) -> Result + Send>>>, ApiError> { + let stream_key = format!("user:{}:chat:{}", user_id.0, session_id); + + // Get all previous events from the Redis stream + let (_, prev_values): (String, Vec<(String, HashMap)>) = redis + .xread::>, _, _>(None, None, &stream_key, "0-0") + .await? + .ok_or(LlmError::StreamNotFound)? + .pop() + .ok_or(LlmError::StreamNotFound)?; + let last_event = prev_values.last().cloned(); + let prev_events_sse = prev_values + .into_iter() + .filter_map(convert_redis_event_to_sse); + let prev_events_stream = rocket::futures::stream::iter(prev_events_sse); + + // If `end` event already received, just return previous events + if let Some((_, ref data)) = last_event { + if data.get("type").is_some_and(|t| t == "end") { + return Ok(EventStream::from(prev_events_stream.boxed())); + } + } + + // Spawn a task to receive new events from Redis and add them to the channel + let (tx, rx) = tokio::sync::mpsc::channel::(50); + tokio::spawn(async move { + let mut last_event_id = last_event.map(|(id, _)| id).unwrap_or_else(|| "0-0".into()); + loop { + match get_next_event(&redis, &stream_key, &last_event_id, &tx).await { + Ok(Some((id, event))) => { + last_event_id = id; + if let Err(_) = tx.send(event).await { + break; // client disconnected, stop sending events + } + } + Ok(None) => { + tx.send(Event::empty().event("end")).await.ok(); + break; // No more events, end of stream + } + Err(e) => { + let event = Event::data(format!("Error: {}", e)).event("error"); + tx.send(event).await.ok(); + break; + } + } + } + drop(tx); + }); + + // Send stream of events from Redis to the client, starting with all previous events + let stream = prev_events_stream.chain(ReceiverStream::new(rx)).boxed(); + Ok(EventStream::from(stream)) +} + +async fn get_next_event( + redis: &RedisClient, + stream_key: &str, + last_event_id: &str, + tx: &tokio::sync::mpsc::Sender, +) -> Result, LlmError> { + let next_stream_value: Option)>)>> = tokio::select! { + next_value = redis.xread::<_, _, _>(Some(1), Some(5_000), stream_key, last_event_id) => next_value?, + _ = tx.closed() => { + println!("Client disconnected"); + return Ok(None); + }, + }; + match next_stream_value + .and_then(|mut v| v.pop()) + .and_then(|(_, mut events)| events.pop()) + { + Some((id, event)) => { + if event.get("type").is_some_and(|t| t == "end") { + return Ok(None); // handle `end` event + } + Ok(Some(id.clone()).zip(convert_redis_event_to_sse((id, event)))) + } + None => Ok(None), + } +} + +fn convert_redis_event_to_sse((id, event): (String, HashMap)) -> Option { + let mut r#type = None; + let mut data = None; + for (key, value) in event { + match key.as_str() { + "type" => r#type = Some(value), + "data" => data = Some(value), + _ => {} + } + } + if let Some(r#type) = r#type { + Some(Event::data(data.unwrap_or_default()).event(r#type).id(id)) + } else { + None + } +} diff --git a/server/src/config.rs b/server/src/config.rs index 18238ca..1476ba8 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -20,7 +20,7 @@ pub struct AppConfig { pub database_url: String, /// Redis connection URL pub redis_url: String, - /// Redis pool size (default: 2) + /// Redis pool size (default: 4) pub redis_pool: Option, } diff --git a/server/src/provider.rs b/server/src/provider.rs index e911ca4..7ee1235 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -36,6 +36,10 @@ pub enum LlmError { NoResponse, #[error("Unsupported provider")] UnsupportedProvider, + #[error("Already streaming a response for this session")] + AlreadyStreaming, + #[error("No ongoing stream for this session")] + StreamNotFound, #[error("Encryption error")] EncryptionError, #[error("Decryption error")] diff --git a/server/src/redis.rs b/server/src/redis.rs index 6de13c0..ecd1c07 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -56,7 +56,7 @@ pub fn setup_redis() -> AdHoc { ..Default::default() }; }) - .build_pool(app_config.redis_pool.unwrap_or(2)) + .build_pool(app_config.redis_pool.unwrap_or(4)) .expect("Failed to build Redis pool"); pool.init().await.expect("Failed to connect to Redis"); diff --git a/server/src/utils/llm_stream.rs b/server/src/utils/llm_stream.rs index 4fd57d6..ca896c4 100644 --- a/server/src/utils/llm_stream.rs +++ b/server/src/utils/llm_stream.rs @@ -1,7 +1,11 @@ -use std::time::{Duration, Instant}; +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; -use fred::prelude::StreamsInterface; +use fred::prelude::{KeysInterface, StreamsInterface}; use rocket::futures::StreamExt; +use serde::Serialize; use crate::{ db::models::ChatRsToolCall, @@ -9,7 +13,7 @@ use crate::{ }; const MAX_CHUNK_SIZE: usize = 1000; -const MAX_FLUSH_TIME: Duration = Duration::from_millis(500); +const FLUSH_INTERVAL: Duration = Duration::from_millis(500); /// Utility struct for processing an incoming LLM stream and intermittently /// flushing the data to a Redis stream. @@ -17,30 +21,42 @@ const MAX_FLUSH_TIME: Duration = Duration::from_millis(500); pub struct LlmStreamProcessor { redis: fred::prelude::Client, /// The current chunk of data being processed. - current_chunk: RedisStreamChunkData, + current_chunk: ChunkState, /// Accumulated text response from the assistant. complete_text: Option, /// Accumulated tool calls from the assistant. tool_calls: Option>, - /// Accumulated errors during the stream from the assistant. + /// Accumulated errors during the stream from the LLM provider. errors: Option>, - /// Accumulated usage information from the assistant. + /// Accumulated usage information from the LLM provider. usage: Option, } -#[derive(Debug)] -enum RedisStreamChunk { - Data(RedisStreamChunkData), - End, -} - #[derive(Debug, Default)] -struct RedisStreamChunkData { +struct ChunkState { text: Option, tool_calls: Option>, error: Option, } +/// Chunk of the LLM response stored in the Redis stream. +#[derive(Debug, Serialize)] +#[serde(tag = "type", content = "data", rename_all = "snake_case")] +enum RedisStreamChunk { + Start, + Text(String), + ToolCall(String), + Error(String), + End, +} +impl From for HashMap { + /// Converts a `RedisStreamChunk` into a hash map, suitable for the Redis client. + fn from(chunk: RedisStreamChunk) -> Self { + let value = serde_json::to_value(chunk).unwrap_or_default(); + serde_json::from_value(value).unwrap_or_default() + } +} + impl LlmStreamProcessor { pub fn new(redis: &fred::prelude::Client) -> Self { LlmStreamProcessor { @@ -61,8 +77,11 @@ impl LlmStreamProcessor { Option, Option>, ) { - let mut last_flush_time = Instant::now(); + if let Err(e) = self.notify_start_of_redis_stream(&stream_key).await { + self.errors.get_or_insert_default().push(LlmError::Redis(e)); + }; + let mut last_flush_time = Instant::now(); while let Some(chunk) = stream.next().await { match chunk { Ok(chunk) => { @@ -77,18 +96,17 @@ impl LlmStreamProcessor { } } Err(err) => { - self.current_chunk.error = Some(err.to_string()); - self.errors.get_or_insert_default().push(err); + self.process_error(err); } } if self.should_flush(&last_flush_time) { - self.flush_and_reset_chunk(&stream_key).await; + self.flush_and_reset(&stream_key).await; last_flush_time = Instant::now(); } } - if let Err(e) = self.mark_end_of_redis_stream(&stream_key).await { + if let Err(e) = self.notify_end_of_redis_stream(&stream_key).await { self.errors.get_or_insert_default().push(LlmError::Redis(e)); }; @@ -126,69 +144,67 @@ impl LlmStreamProcessor { } } + fn process_error(&mut self, err: LlmError) { + self.current_chunk.error = Some(err.to_string()); + self.errors.get_or_insert_default().push(err); + } + fn should_flush(&self, last_flush_time: &Instant) -> bool { - // Flush if there are any tool calls or errors if self.current_chunk.tool_calls.is_some() || self.current_chunk.error.is_some() { return true; } - // Skip flushing if chunk is completely empty - if self.current_chunk.text.is_none() { - return false; + if let Some(ref text) = self.current_chunk.text { + return text.len() > MAX_CHUNK_SIZE || last_flush_time.elapsed() > FLUSH_INTERVAL; } - // Check for time and size triggers - last_flush_time.elapsed() > MAX_FLUSH_TIME - || self - .current_chunk - .text - .as_ref() - .is_some_and(|t| t.len() > MAX_CHUNK_SIZE) + return false; } - async fn add_to_redis_stream( + async fn flush_and_reset(&mut self, stream_key: &str) { + let chunk_state = std::mem::take(&mut self.current_chunk); + + let mut chunks: Vec = Vec::with_capacity(2); + if let Some(text) = chunk_state.text { + chunks.push(RedisStreamChunk::Text(text)); + } + if let Some(tool_calls) = chunk_state.tool_calls { + chunks.extend(tool_calls.into_iter().map(|tc| { + RedisStreamChunk::ToolCall(serde_json::to_string(&tc).unwrap_or_default()) + })); + } + if let Some(error) = chunk_state.error { + chunks.push(RedisStreamChunk::Error(error)); + } + + let entries = chunks.into_iter().map(|chunk| chunk.into()).collect(); + let _ = self.add_to_redis_stream(stream_key, entries).await; + } + + async fn notify_start_of_redis_stream( &mut self, stream_key: &str, - data: Vec<(&str, String)>, ) -> Result<(), fred::prelude::Error> { - self.redis.xadd(stream_key, false, None, "*", data).await - } - - async fn flush_and_reset_chunk(&mut self, stream_key: &str) { - let chunk = std::mem::take(&mut self.current_chunk); - if let Ok(data) = RedisStreamChunk::Data(chunk).try_into() { - let _ = self.add_to_redis_stream(stream_key, data).await; - } + let entries = vec![RedisStreamChunk::Start.into()]; + self.add_to_redis_stream(stream_key, entries).await } - async fn mark_end_of_redis_stream( + async fn notify_end_of_redis_stream( &mut self, stream_key: &str, ) -> Result<(), fred::prelude::Error> { - let data = RedisStreamChunk::End.try_into().expect("Should convert"); - self.add_to_redis_stream(stream_key, data).await + let entries = vec![RedisStreamChunk::End.into()]; + self.add_to_redis_stream(stream_key, entries).await } -} -impl TryFrom for Vec<(&str, String)> { - type Error = serde_json::Error; - - /// Converts a `RedisStreamChunk` into a vector of key-value pairs, suitable for the Redis client. - fn try_from(chunk: RedisStreamChunk) -> Result { - match chunk { - RedisStreamChunk::Data(data) => { - let mut vec = Vec::with_capacity(3); - vec.push(("type", "data".into())); - if let Some(text) = data.text { - vec.push(("text", text)); - } - if let Some(tool_calls) = data.tool_calls { - vec.push(("tool_calls", serde_json::to_string(&tool_calls)?)); - } - if let Some(error) = data.error { - vec.push(("error", error)); - } - Ok(vec) - } - RedisStreamChunk::End => Ok(vec![("type", "end".into())]), + async fn add_to_redis_stream( + &mut self, + stream_key: &str, + entries: Vec>, + ) -> Result<(), fred::prelude::Error> { + let pipeline = self.redis.pipeline(); + for entry in entries { + let _: () = pipeline.xadd(stream_key, false, None, "*", entry).await?; } + let _: () = pipeline.expire(stream_key, 60, None).await?; + pipeline.all().await } } From 839103d702161570bfe5637481cb174abd72a1b1 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 21 Aug 2025 23:12:00 -0400 Subject: [PATCH 03/46] server: refactor tool and title generation utils --- server/Cargo.lock | 1 + server/Cargo.toml | 1 + server/src/api/chat.rs | 99 +++++++++----------------- server/src/errors.rs | 3 + server/src/provider.rs | 13 ++-- server/src/provider/anthropic.rs | 33 ++++----- server/src/provider/lorem.rs | 3 +- server/src/provider/openai.rs | 41 +++++------ server/src/tools.rs | 31 +++++++- server/src/utils/generate_title.rs | 109 ++++++++++++----------------- 10 files changed, 161 insertions(+), 173 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index 6af3994..de142b9 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -383,6 +383,7 @@ dependencies = [ "diesel_as_jsonb", "diesel_async_migrations", "dotenvy", + "dyn-clone", "enum-iterator", "fred", "hex", diff --git a/server/Cargo.toml b/server/Cargo.toml index a6d39bc..3c2bef7 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -33,6 +33,7 @@ diesel-derive-enum = { version = "3.0.0-beta.1", features = ["postgres"] } diesel_as_jsonb = "1.0.1" diesel_async_migrations = "0.15.0" dotenvy = "0.15.7" +dyn-clone = "1.0.19" enum-iterator = "2.1.0" fred = { version = "10.1.0", default-features = false, features = [ "i-keys", diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index 24c8422..2235841 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -20,16 +20,16 @@ use crate::{ auth::ChatRsUserId, db::{ models::{ - AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsProviderType, - ChatRsSessionMeta, NewChatRsMessage, UpdateChatRsSession, + AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsSessionMeta, + NewChatRsMessage, UpdateChatRsSession, }, services::{ChatDbService, ProviderDbService, ToolDbService}, DbConnection, DbPool, }, errors::ApiError, - provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmError, LlmTool}, + provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmError}, redis::RedisClient, - tools::SendChatToolInput, + tools::{get_llm_tools_from_input, SendChatToolInput}, utils::{generate_title, Encryptor, LlmStreamProcessor, StoredChatRsStream}, }; @@ -71,12 +71,11 @@ pub async fn send_chat_stream( let (provider, api_key_secret) = ProviderDbService::new(&mut db) .get_by_id(&user_id, input.provider_id) .await?; - let provider_type: ChatRsProviderType = provider.provider_type.as_str().try_into()?; let api_key = api_key_secret .map(|secret| encryptor.decrypt_string(&secret.ciphertext, &secret.nonce)) .transpose()?; let provider_api = build_llm_provider_api( - &provider_type, + &provider.provider_type.as_str().try_into()?, provider.base_url.as_deref(), api_key.as_deref(), &http_client, @@ -84,22 +83,13 @@ pub async fn send_chat_stream( )?; // Get the user's chosen tools - let mut llm_tools: Option> = None; - let mut tool_db_service = ToolDbService::new(&mut db); - if let Some(system_tool_input) = input.tools.as_ref().and_then(|t| t.system.as_ref()) { - let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; - let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; - llm_tools.get_or_insert_default().extend(system_llm_tools); - } - if let Some(external_apis_input) = input.tools.as_ref().and_then(|t| t.external_apis.as_ref()) { - let external_api_tools = tool_db_service - .find_external_api_tools_by_user(&user_id) - .await?; - for tool_input in external_apis_input { - let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; - llm_tools.get_or_insert_default().extend(api_llm_tools); + let llm_tools = match input.tools.as_ref() { + Some(tool_input) => { + let mut tool_db_service = ToolDbService::new(&mut db); + Some(get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?) } - } + None => None, + }; // Save user message and generate session title if needed if let Some(user_message) = &input.message { @@ -108,16 +98,11 @@ pub async fn send_chat_stream( &user_id, &session_id, &user_message, - provider_type, + &provider_api, &provider.default_model, - provider.base_url.as_deref(), - api_key.clone(), - &http_client, - &redis, db_pool, ); } - let new_message = ChatDbService::new(&mut db) .save_message(NewChatRsMessage { content: user_message, @@ -174,7 +159,8 @@ pub async fn send_chat_stream( } /// # Start chat stream -/// Send a chat message and start streaming the response +/// Send a chat message and start the streamed assistant response. After +/// the response has started, use the `/stream` endpoint to connect to the SSE stream. #[openapi(tag = "Chat")] #[post("//v2", data = "")] pub async fn send_chat_stream_v2( @@ -202,12 +188,11 @@ pub async fn send_chat_stream_v2( let (provider, api_key_secret) = ProviderDbService::new(&mut db) .get_by_id(&user_id, input.provider_id) .await?; - let provider_type: ChatRsProviderType = provider.provider_type.as_str().try_into()?; let api_key = api_key_secret .map(|secret| encryptor.decrypt_string(&secret.ciphertext, &secret.nonce)) .transpose()?; let provider_api = build_llm_provider_api( - &provider_type, + &provider.provider_type.as_str().try_into()?, provider.base_url.as_deref(), api_key.as_deref(), &http_client, @@ -215,40 +200,26 @@ pub async fn send_chat_stream_v2( )?; // Get the user's chosen tools - let mut llm_tools: Option> = None; - let mut tool_db_service = ToolDbService::new(&mut db); - if let Some(system_tool_input) = input.tools.as_ref().and_then(|t| t.system.as_ref()) { - let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; - let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; - llm_tools.get_or_insert_default().extend(system_llm_tools); - } - if let Some(external_apis_input) = input.tools.as_ref().and_then(|t| t.external_apis.as_ref()) { - let external_api_tools = tool_db_service - .find_external_api_tools_by_user(&user_id) - .await?; - for tool_input in external_apis_input { - let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; - llm_tools.get_or_insert_default().extend(api_llm_tools); + let llm_tools = match input.tools.as_ref() { + Some(tool_input) => { + let mut tool_db_service = ToolDbService::new(&mut db); + Some(get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?) } - } + None => None, + }; - // Save user message and generate session title if needed + // Generate session title if needed, and save user message to database if let Some(user_message) = &input.message { if current_messages.is_empty() && session.title == DEFAULT_SESSION_TITLE { generate_title( &user_id, &session_id, &user_message, - provider_type, + &provider_api, &provider.default_model, - provider.base_url.as_deref(), - api_key.clone(), - &http_client, - &redis, db_pool, ); } - let new_message = ChatDbService::new(&mut db) .save_message(NewChatRsMessage { content: user_message, @@ -336,8 +307,7 @@ pub async fn connect_to_chat_stream( let (_, prev_values): (String, Vec<(String, HashMap)>) = redis .xread::>, _, _>(None, None, &stream_key, "0-0") .await? - .ok_or(LlmError::StreamNotFound)? - .pop() + .and_then(|mut streams| streams.pop()) .ok_or(LlmError::StreamNotFound)?; let last_event = prev_values.last().cloned(); let prev_events_sse = prev_values @@ -378,7 +348,7 @@ pub async fn connect_to_chat_stream( drop(tx); }); - // Send stream of events from Redis to the client, starting with all previous events + // Send stream of events from Redis to the client, starting with all previous events and then new events let stream = prev_events_stream.chain(ReceiverStream::new(rx)).boxed(); Ok(EventStream::from(stream)) } @@ -389,20 +359,19 @@ async fn get_next_event( last_event_id: &str, tx: &tokio::sync::mpsc::Sender, ) -> Result, LlmError> { - let next_stream_value: Option)>)>> = tokio::select! { - next_value = redis.xread::<_, _, _>(Some(1), Some(5_000), stream_key, last_event_id) => next_value?, - _ = tx.closed() => { - println!("Client disconnected"); - return Ok(None); + let (_, mut events): (String, Vec<(String, HashMap)>) = tokio::select! { + next_value = redis.xread::>, _, _>(Some(1), Some(8_000), stream_key, last_event_id) => { + match next_value?.as_mut().and_then(|streams| streams.pop()) { + Some(s) => s, + None => return Ok(None), + } }, + _ = tx.closed() => return Ok(None) }; - match next_stream_value - .and_then(|mut v| v.pop()) - .and_then(|(_, mut events)| events.pop()) - { + match events.pop() { Some((id, event)) => { if event.get("type").is_some_and(|t| t == "end") { - return Ok(None); // handle `end` event + return Ok(None); } Ok(Some(id.clone()).zip(convert_redis_event_to_sse((id, event)))) } diff --git a/server/src/errors.rs b/server/src/errors.rs index 958210f..7adefed 100644 --- a/server/src/errors.rs +++ b/server/src/errors.rs @@ -1,3 +1,4 @@ +use diesel_async::pooled_connection::deadpool; use rocket::{ catch, catchers, response::{self, Responder}, @@ -13,6 +14,8 @@ use crate::{provider::LlmError, tools::ToolError}; pub enum ApiError { #[error(transparent)] Db(#[from] diesel::result::Error), + #[error(transparent)] + DbPool(#[from] deadpool::PoolError), #[error("Authentication error: {0}")] Authentication(String), #[error("Redis error: {0}")] diff --git a/server/src/provider.rs b/server/src/provider.rs index 7ee1235..29bfa9b 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -6,6 +6,7 @@ pub mod openai; use std::pin::Pin; +use dyn_clone::DynClone; use rocket::{async_trait, futures::Stream}; use schemars::JsonSchema; use uuid::Uuid; @@ -99,7 +100,7 @@ pub enum LlmToolType { /// Unified API for LLM providers #[async_trait] -pub trait LlmApiProvider { +pub trait LlmApiProvider: Send + Sync + DynClone { /// Stream a chat response from the provider async fn chat_stream( &self, @@ -120,13 +121,13 @@ pub trait LlmApiProvider { } /// Build the LLM API to make calls to the provider -pub fn build_llm_provider_api<'a>( +pub fn build_llm_provider_api( provider_type: &ChatRsProviderType, - base_url: Option<&'a str>, - api_key: Option<&'a str>, + base_url: Option<&str>, + api_key: Option<&str>, http_client: &reqwest::Client, - redis: &'a fred::prelude::Client, -) -> Result, LlmError> { + redis: &fred::prelude::Client, +) -> Result, LlmError> { match provider_type { ChatRsProviderType::Openai => Ok(Box::new(OpenAIProvider::new( http_client, diff --git a/server/src/provider/anthropic.rs b/server/src/provider/anthropic.rs index dc4a631..078acec 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/anthropic.rs @@ -18,26 +18,27 @@ const MESSAGES_API_URL: &str = "https://api.anthropic.com/v1/messages"; const API_VERSION: &str = "2023-06-01"; /// Anthropic chat provider -pub struct AnthropicProvider<'a> { +#[derive(Debug, Clone)] +pub struct AnthropicProvider { client: reqwest::Client, - redis: &'a fred::prelude::Client, - api_key: &'a str, + redis: fred::prelude::Client, + api_key: String, } -impl<'a> AnthropicProvider<'a> { +impl AnthropicProvider { pub fn new( http_client: &reqwest::Client, - redis: &'a fred::prelude::Client, - api_key: &'a str, + redis: &fred::prelude::Client, + api_key: &str, ) -> Self { Self { client: http_client.clone(), - redis, - api_key, + redis: redis.clone(), + api_key: api_key.to_string(), } } - fn build_messages( + fn build_messages<'a>( &self, messages: &'a [ChatRsMessage], ) -> (Vec>, Option<&'a str>) { @@ -104,7 +105,7 @@ impl<'a> AnthropicProvider<'a> { (anthropic_messages, system_prompt) } - fn build_tools(&self, tools: &'a [LlmTool]) -> Vec> { + fn build_tools<'a>(&self, tools: &'a [LlmTool]) -> Vec> { tools .iter() .map(|tool| AnthropicTool { @@ -252,7 +253,7 @@ impl<'a> AnthropicProvider<'a> { } #[async_trait] -impl<'a> LlmApiProvider for AnthropicProvider<'a> { +impl LlmApiProvider for AnthropicProvider { async fn chat_stream( &self, messages: Vec, @@ -277,7 +278,7 @@ impl<'a> LlmApiProvider for AnthropicProvider<'a> { .post(MESSAGES_API_URL) .header("anthropic-version", API_VERSION) .header("content-type", "application/json") - .header("x-api-key", self.api_key) + .header("x-api-key", &self.api_key) .json(&request) .send() .await @@ -318,7 +319,7 @@ impl<'a> LlmApiProvider for AnthropicProvider<'a> { .post(MESSAGES_API_URL) .header("anthropic-version", API_VERSION) .header("content-type", "application/json") - .header("x-api-key", self.api_key) + .header("x-api-key", &self.api_key) .json(&request) .send() .await @@ -333,16 +334,16 @@ impl<'a> LlmApiProvider for AnthropicProvider<'a> { ))); } - let anthropic_response: AnthropicResponse = response + let mut anthropic_response: AnthropicResponse = response .json() .await .map_err(|e| LlmError::AnthropicError(format!("Failed to parse response: {}", e)))?; let text = anthropic_response .content - .first() + .get_mut(0) .and_then(|block| match block { - AnthropicResponseContentBlock::Text { text } => Some(text.clone()), + AnthropicResponseContentBlock::Text { text } => Some(std::mem::take(text)), _ => None, }) .ok_or_else(|| LlmError::NoResponse)?; diff --git a/server/src/provider/lorem.rs b/server/src/provider/lorem.rs index 0d87efd..a3a8f24 100644 --- a/server/src/provider/lorem.rs +++ b/server/src/provider/lorem.rs @@ -17,11 +17,12 @@ use crate::{ }; /// A test/dummy provider that streams 'lorem ipsum...' +#[derive(Debug, Clone)] pub struct LoremProvider { pub config: LoremConfig, } -#[derive(JsonSchema)] +#[derive(Debug, Clone, JsonSchema)] pub struct LoremConfig { pub interval: u32, } diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index b8b0a3f..b6c0a68 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -16,29 +16,30 @@ const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1"; const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1"; /// OpenAI chat provider -pub struct OpenAIProvider<'a> { +#[derive(Debug, Clone)] +pub struct OpenAIProvider { client: reqwest::Client, - redis: &'a fred::prelude::Client, - api_key: &'a str, - base_url: &'a str, + redis: fred::prelude::Client, + api_key: String, + base_url: String, } -impl<'a> OpenAIProvider<'a> { +impl OpenAIProvider { pub fn new( http_client: &reqwest::Client, - redis: &'a fred::prelude::Client, - api_key: &'a str, - base_url: Option<&'a str>, + redis: &fred::prelude::Client, + api_key: &str, + base_url: Option<&str>, ) -> Self { Self { client: http_client.clone(), - redis, - api_key, - base_url: base_url.unwrap_or(OPENAI_API_BASE_URL), + redis: redis.clone(), + api_key: api_key.to_owned(), + base_url: base_url.unwrap_or(OPENAI_API_BASE_URL).to_owned(), } } - fn build_messages(&self, messages: &'a [ChatRsMessage]) -> Vec> { + fn build_messages<'a>(&self, messages: &'a [ChatRsMessage]) -> Vec> { messages .iter() .map(|message| { @@ -77,7 +78,7 @@ impl<'a> OpenAIProvider<'a> { .collect() } - fn build_tools(&self, tools: &'a [LlmTool]) -> Vec { + fn build_tools<'a>(&self, tools: &'a [LlmTool]) -> Vec> { tools .iter() .map(|tool| OpenAITool { @@ -189,7 +190,7 @@ impl<'a> OpenAIProvider<'a> { } #[async_trait] -impl<'a> LlmApiProvider for OpenAIProvider<'a> { +impl LlmApiProvider for OpenAIProvider { async fn chat_stream( &self, messages: Vec, @@ -269,16 +270,16 @@ impl<'a> LlmApiProvider for OpenAIProvider<'a> { ))); } - let openai_response: OpenAIResponse = response + let mut openai_response: OpenAIResponse = response .json() .await .map_err(|e| LlmError::OpenAIError(format!("Failed to parse response: {}", e)))?; let text = openai_response .choices - .first() - .and_then(|choice| choice.message.as_ref()) - .and_then(|message| message.content.as_ref()) + .get_mut(0) + .and_then(|choice| choice.message.as_mut()) + .and_then(|message| message.content.take()) .ok_or(LlmError::NoResponse)?; if let Some(usage) = openai_response.usage { @@ -286,14 +287,14 @@ impl<'a> LlmApiProvider for OpenAIProvider<'a> { println!("Prompt usage: {:?}", usage); } - Ok(text.clone()) + Ok(text) } async fn list_models(&self) -> Result, LlmError> { let models_service = ModelsDevService::new(self.redis.clone(), self.client.clone()); let models = models_service .list_models({ - match self.base_url { + match self.base_url.as_str() { OPENROUTER_API_BASE_URL => ModelsDevServiceProvider::OpenRouter, _ => ModelsDevServiceProvider::OpenAI, } diff --git a/server/src/tools.rs b/server/src/tools.rs index e2d1d11..7d81cab 100644 --- a/server/src/tools.rs +++ b/server/src/tools.rs @@ -9,7 +9,11 @@ pub use { system::{ChatRsSystemToolConfig, SystemToolInput}, }; -use schemars::JsonSchema; +use { + crate::{db::services::ToolDbService, errors::ApiError, provider::LlmTool}, + schemars::JsonSchema, + uuid::Uuid, +}; /// User configuration of tools when sending a chat message #[derive(Debug, Default, PartialEq, JsonSchema, serde::Serialize, serde::Deserialize)] @@ -17,3 +21,28 @@ pub struct SendChatToolInput { pub system: Option, pub external_apis: Option>, } + +/// Get all tools from the user's input in LLM generic format +pub async fn get_llm_tools_from_input( + user_id: &Uuid, + input: &SendChatToolInput, + tool_db_service: &mut ToolDbService<'_>, +) -> Result, ApiError> { + let mut llm_tools = Vec::with_capacity(5); + if let Some(ref system_tool_input) = input.system { + let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; + let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; + llm_tools.extend(system_llm_tools); + } + if let Some(ref external_apis_input) = input.external_apis { + let external_api_tools = tool_db_service + .find_external_api_tools_by_user(&user_id) + .await?; + for tool_input in external_apis_input { + let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; + llm_tools.extend(api_llm_tools); + } + } + + Ok(llm_tools) +} diff --git a/server/src/utils/generate_title.rs b/server/src/utils/generate_title.rs index f61589a..c8ce15a 100644 --- a/server/src/utils/generate_title.rs +++ b/server/src/utils/generate_title.rs @@ -1,87 +1,68 @@ +use std::ops::Deref; + use uuid::Uuid; use crate::{ - db::{ - models::{ChatRsProviderType, UpdateChatRsSession}, - services::ChatDbService, - DbConnection, DbPool, - }, - provider::{build_llm_provider_api, LlmApiProviderSharedOptions, DEFAULT_TEMPERATURE}, + db::{models::UpdateChatRsSession, services::ChatDbService, DbConnection, DbPool}, + errors::ApiError, + provider::{LlmApiProvider, LlmApiProviderSharedOptions, DEFAULT_TEMPERATURE}, }; -const TITLE_PROMPT: &str = "This is the first message sent by a human in a session with an AI chatbot. Please generate a short title for the session (max 6 words) in plain text"; +const TITLE_TOKENS: u32 = 20; +const TITLE_PROMPT: &str = + "This is the first message sent by a human in a chat session with an AI chatbot. \ + Please generate a short title for the session (3-7 words) in plain text \ + (no quotes or prefixes)."; -/// Spawns a task to generate a title for the chat session +/// Spawn a task to generate a title for the chat session pub fn generate_title( user_id: &Uuid, session_id: &Uuid, user_message: &str, - provider_type: ChatRsProviderType, + provider: &Box, model: &str, - base_url: Option<&str>, - api_key: Option, - http_client: &reqwest::Client, - redis: &fred::prelude::Client, pool: &DbPool, ) { let user_id = user_id.to_owned(); let session_id = session_id.to_owned(); let user_message = user_message.to_owned(); + let provider = dyn_clone::clone_box(provider.deref()); let model = model.to_owned(); - let base_url = base_url.map(|url| url.to_owned()); - let http_client = http_client.clone(); - let redis = redis.clone(); let pool = pool.clone(); tokio::spawn(async move { - let Ok(conn) = pool.get().await else { - rocket::error!("Couldn't get database connection"); - return; - }; - let mut db = DbConnection(conn); - let Ok(provider) = build_llm_provider_api( - &provider_type, - base_url.as_deref(), - api_key.as_deref(), - &http_client, - &redis, - ) else { - rocket::warn!("Error creating provider for chat {}", session_id); - return; - }; - - let provider_options = LlmApiProviderSharedOptions { - model, - temperature: Some(DEFAULT_TEMPERATURE), - max_tokens: Some(20), - }; - let provider_response = provider - .prompt( - &format!("{}: \"{}\"", TITLE_PROMPT, user_message), - &provider_options, - ) - .await; - - match provider_response { - Ok(title) => { - rocket::info!("Generated title for chat {}", session_id); - if let Err(e) = ChatDbService::new(&mut db) - .update_session( - &user_id, - &session_id, - UpdateChatRsSession { - title: Some(title.trim()), - ..Default::default() - }, - ) - .await - { - rocket::warn!("Error saving title for chat {}: {}", session_id, e); - }; - } - Err(e) => { - rocket::warn!("Error generating title for chat {}: {}", session_id, e); - } + if let Err(err) = generate(user_id, session_id, user_message, provider, model, pool).await { + rocket::warn!("Failed to generate title: {}", err); } }); } + +async fn generate( + user_id: Uuid, + session_id: Uuid, + user_message: String, + provider: Box, + model: String, + pool: DbPool, +) -> Result<(), ApiError> { + let provider_options = LlmApiProviderSharedOptions { + model, + temperature: Some(DEFAULT_TEMPERATURE), + max_tokens: Some(TITLE_TOKENS), + }; + let message = format!("{}: \"{}\"", TITLE_PROMPT, user_message); + let title = provider.prompt(&message, &provider_options).await?; + + let mut db = DbConnection(pool.get().await?); + ChatDbService::new(&mut db) + .update_session( + &user_id, + &session_id, + UpdateChatRsSession { + title: Some(title.trim()), + ..Default::default() + }, + ) + .await?; + Ok(()) +} From 3d0cbff0ee0216df905a79d21b1a7dc113501e66 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 22 Aug 2025 01:43:51 -0400 Subject: [PATCH 04/46] server: refactor and organize Redis stream logic --- server/src/api/chat.rs | 145 +++++------------- server/src/lib.rs | 1 + server/src/provider.rs | 4 + server/src/stream.rs | 5 + .../llm_stream.rs => stream/llm_writer.rs} | 39 +++-- server/src/stream/reader.rs | 115 ++++++++++++++ server/src/utils.rs | 2 - 7 files changed, 186 insertions(+), 125 deletions(-) create mode 100644 server/src/stream.rs rename server/src/{utils/llm_stream.rs => stream/llm_writer.rs} (86%) create mode 100644 server/src/stream/reader.rs diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index 2235841..f60045b 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -1,9 +1,9 @@ -use std::{borrow::Cow, collections::HashMap, pin::Pin}; +use std::{borrow::Cow, pin::Pin}; -use fred::prelude::{KeysInterface, StreamsInterface}; +use fred::prelude::KeysInterface; use rocket::{ - futures::{Stream, StreamExt}, - post, + futures::{stream, Stream, StreamExt}, + get, post, response::stream::{Event, EventStream}, serde::json::Json, Route, State, @@ -29,8 +29,9 @@ use crate::{ errors::ApiError, provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmError}, redis::RedisClient, + stream::{LlmStreamWriter, SseStreamReader}, tools::{get_llm_tools_from_input, SendChatToolInput}, - utils::{generate_title, Encryptor, LlmStreamProcessor, StoredChatRsStream}, + utils::{generate_title, Encryptor, StoredChatRsStream}, }; pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { @@ -51,7 +52,7 @@ pub struct SendChatInput<'a> { /// Send a chat message and stream the response #[openapi(tag = "Chat")] -#[post("/", data = "")] +#[post("//v1", data = "")] pub async fn send_chat_stream( user_id: ChatRsUserId, db_pool: &State, @@ -159,15 +160,15 @@ pub async fn send_chat_stream( } /// # Start chat stream -/// Send a chat message and start the streamed assistant response. After -/// the response has started, use the `/stream` endpoint to connect to the SSE stream. +/// Send a chat message and start the streamed assistant response. After the response +/// has started, use the `//stream` endpoint to connect to the SSE stream. #[openapi(tag = "Chat")] -#[post("//v2", data = "")] +#[post("/", data = "")] pub async fn send_chat_stream_v2( user_id: ChatRsUserId, db_pool: &State, mut db: DbConnection, - redis: RedisClient, + redis: &State, encryptor: &State, http_client: &State, session_id: Uuid, @@ -184,7 +185,7 @@ pub async fn send_chat_stream_v2( .get_session_with_messages(&user_id, &session_id) .await?; - // Build the chat provider + // Build the LLM provider let (provider, api_key_secret) = ProviderDbService::new(&mut db) .get_by_id(&user_id, input.provider_id) .await?; @@ -196,17 +197,16 @@ pub async fn send_chat_stream_v2( provider.base_url.as_deref(), api_key.as_deref(), &http_client, - &redis, + redis.next(), )?; // Get the user's chosen tools - let llm_tools = match input.tools.as_ref() { - Some(tool_input) => { - let mut tool_db_service = ToolDbService::new(&mut db); - Some(get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?) - } - None => None, - }; + let mut llm_tools = None; + if let Some(tool_input) = input.tools.as_ref() { + let mut tool_db_service = ToolDbService::new(&mut db); + let tools = get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?; + llm_tools = Some(tools); + } // Generate session title if needed, and save user message to database if let Some(user_message) = &input.message { @@ -253,18 +253,18 @@ pub async fn send_chat_stream_v2( } } - // Get the provider's stream response, and spawn a task to stream it to Redis - // and save the response to the database on completion + // Get the provider's stream response let stream = provider_api .chat_stream(current_messages, llm_tools, &input.provider_options) .await?; - let stream_processor = LlmStreamProcessor::new(&redis); let provider_id = input.provider_id; let provider_options = input.provider_options.clone(); + + // Spawn a task to stream the response to Redis and save it to the database on completion + let stream_writer = LlmStreamWriter::new(&redis); tokio::spawn(async move { - let (content, tool_calls, usage, _) = stream_processor - .process_llm_stream(&stream_key, stream) - .await; + let (content, tool_calls, usage, _) = + stream_writer.process_stream(&stream_key, stream).await; if let Err(e) = ChatDbService::new(&mut db) .save_message(NewChatRsMessage { session_id: &session_id, @@ -295,103 +295,32 @@ pub async fn send_chat_stream_v2( /// # Connect to chat stream /// Connect to an ongoing chat stream and stream the assistant response #[openapi(tag = "Chat")] -#[post("//stream")] +#[get("//stream")] pub async fn connect_to_chat_stream( user_id: ChatRsUserId, - redis: RedisClient, + redis: &State, session_id: Uuid, ) -> Result + Send>>>, ApiError> { let stream_key = format!("user:{}:chat:{}", user_id.0, session_id); + let stream_reader = SseStreamReader::new(&redis); - // Get all previous events from the Redis stream - let (_, prev_values): (String, Vec<(String, HashMap)>) = redis - .xread::>, _, _>(None, None, &stream_key, "0-0") - .await? - .and_then(|mut streams| streams.pop()) - .ok_or(LlmError::StreamNotFound)?; - let last_event = prev_values.last().cloned(); - let prev_events_sse = prev_values - .into_iter() - .filter_map(convert_redis_event_to_sse); - let prev_events_stream = rocket::futures::stream::iter(prev_events_sse); - - // If `end` event already received, just return previous events - if let Some((_, ref data)) = last_event { - if data.get("type").is_some_and(|t| t == "end") { - return Ok(EventStream::from(prev_events_stream.boxed())); - } + // Get all previous events from the Redis stream, and return them if we're already at the end of the stream + let (prev_events, last_event_id, is_end) = stream_reader.get_prev_events(&stream_key).await?; + let prev_events_stream = stream::iter(prev_events); + if is_end { + return Ok(EventStream::from(prev_events_stream.boxed())); } // Spawn a task to receive new events from Redis and add them to the channel let (tx, rx) = tokio::sync::mpsc::channel::(50); tokio::spawn(async move { - let mut last_event_id = last_event.map(|(id, _)| id).unwrap_or_else(|| "0-0".into()); - loop { - match get_next_event(&redis, &stream_key, &last_event_id, &tx).await { - Ok(Some((id, event))) => { - last_event_id = id; - if let Err(_) = tx.send(event).await { - break; // client disconnected, stop sending events - } - } - Ok(None) => { - tx.send(Event::empty().event("end")).await.ok(); - break; // No more events, end of stream - } - Err(e) => { - let event = Event::data(format!("Error: {}", e)).event("error"); - tx.send(event).await.ok(); - break; - } - } - } + stream_reader + .stream_events(&stream_key, &last_event_id, &tx) + .await; drop(tx); }); - // Send stream of events from Redis to the client, starting with all previous events and then new events + // Send stream to client let stream = prev_events_stream.chain(ReceiverStream::new(rx)).boxed(); Ok(EventStream::from(stream)) } - -async fn get_next_event( - redis: &RedisClient, - stream_key: &str, - last_event_id: &str, - tx: &tokio::sync::mpsc::Sender, -) -> Result, LlmError> { - let (_, mut events): (String, Vec<(String, HashMap)>) = tokio::select! { - next_value = redis.xread::>, _, _>(Some(1), Some(8_000), stream_key, last_event_id) => { - match next_value?.as_mut().and_then(|streams| streams.pop()) { - Some(s) => s, - None => return Ok(None), - } - }, - _ = tx.closed() => return Ok(None) - }; - match events.pop() { - Some((id, event)) => { - if event.get("type").is_some_and(|t| t == "end") { - return Ok(None); - } - Ok(Some(id.clone()).zip(convert_redis_event_to_sse((id, event)))) - } - None => Ok(None), - } -} - -fn convert_redis_event_to_sse((id, event): (String, HashMap)) -> Option { - let mut r#type = None; - let mut data = None; - for (key, value) in event { - match key.as_str() { - "type" => r#type = Some(value), - "data" => data = Some(value), - _ => {} - } - } - if let Some(r#type) = r#type { - Some(Event::data(data.unwrap_or_default()).event(r#type).id(id)) - } else { - None - } -} diff --git a/server/src/lib.rs b/server/src/lib.rs index 473e867..4286e4e 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -6,6 +6,7 @@ pub mod errors; pub mod provider; pub mod provider_models; pub mod redis; +pub mod stream; pub mod tools; pub mod utils; pub mod web; diff --git a/server/src/provider.rs b/server/src/provider.rs index 29bfa9b..dccdec9 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -41,6 +41,10 @@ pub enum LlmError { AlreadyStreaming, #[error("No ongoing stream for this session")] StreamNotFound, + #[error("Missing event in stream")] + NoStreamEvent, + #[error("Client disconnected")] + ClientDisconnected, #[error("Encryption error")] EncryptionError, #[error("Decryption error")] diff --git a/server/src/stream.rs b/server/src/stream.rs new file mode 100644 index 0000000..df4e981 --- /dev/null +++ b/server/src/stream.rs @@ -0,0 +1,5 @@ +mod llm_writer; +mod reader; + +pub use llm_writer::*; +pub use reader::*; diff --git a/server/src/utils/llm_stream.rs b/server/src/stream/llm_writer.rs similarity index 86% rename from server/src/utils/llm_stream.rs rename to server/src/stream/llm_writer.rs index ca896c4..1c18be2 100644 --- a/server/src/utils/llm_stream.rs +++ b/server/src/stream/llm_writer.rs @@ -12,14 +12,18 @@ use crate::{ provider::{LlmApiStream, LlmError, LlmUsage}, }; -const MAX_CHUNK_SIZE: usize = 1000; +/// Interval at which chunks are flushed to Redis. const FLUSH_INTERVAL: Duration = Duration::from_millis(500); - -/// Utility struct for processing an incoming LLM stream and intermittently -/// flushing the data to a Redis stream. -#[derive(Debug, Default)] -pub struct LlmStreamProcessor { - redis: fred::prelude::Client, +/// Max accumulated size of the text before it is automatically flushed to Redis. +const MAX_CHUNK_SIZE: usize = 1000; +/// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) +const STREAM_EXPIRE: i64 = 30; + +/// Utility for processing an incoming LLM response stream and intermittently +/// writing chunks to a Redis stream. +#[derive(Debug)] +pub struct LlmStreamWriter { + redis: fred::prelude::Pool, /// The current chunk of data being processed. current_chunk: ChunkState, /// Accumulated text response from the assistant. @@ -32,6 +36,7 @@ pub struct LlmStreamProcessor { usage: Option, } +/// Internal state #[derive(Debug, Default)] struct ChunkState { text: Option, @@ -57,17 +62,21 @@ impl From for HashMap { } } -impl LlmStreamProcessor { - pub fn new(redis: &fred::prelude::Client) -> Self { - LlmStreamProcessor { +impl LlmStreamWriter { + pub fn new(redis: &fred::prelude::Pool) -> Self { + LlmStreamWriter { redis: redis.clone(), - ..Default::default() + current_chunk: ChunkState::default(), + complete_text: None, + tool_calls: None, + errors: None, + usage: None, } } /// Process the incoming stream from the LLM provider, intermittently - /// flush to Redis stream, and return the accumulated results. - pub async fn process_llm_stream( + /// flushing chunks to a Redis stream, and return the final accumulated response. + pub async fn process_stream( mut self, stream_key: &str, mut stream: LlmApiStream, @@ -200,11 +209,11 @@ impl LlmStreamProcessor { stream_key: &str, entries: Vec>, ) -> Result<(), fred::prelude::Error> { - let pipeline = self.redis.pipeline(); + let pipeline = self.redis.next().pipeline(); for entry in entries { let _: () = pipeline.xadd(stream_key, false, None, "*", entry).await?; } - let _: () = pipeline.expire(stream_key, 60, None).await?; + let _: () = pipeline.expire(stream_key, STREAM_EXPIRE, None).await?; pipeline.all().await } } diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs new file mode 100644 index 0000000..4bc2830 --- /dev/null +++ b/server/src/stream/reader.rs @@ -0,0 +1,115 @@ +use std::collections::HashMap; + +use fred::prelude::StreamsInterface; +use rocket::response::stream::Event; +use tokio::sync::mpsc; + +use crate::provider::LlmError; + +/// Timeout for the blocking `xread` command. +const XREAD_BLOCK_TIMEOUT: u64 = 10_000; // 10 seconds + +/// Utility for reading SSE events from a Redis stream. +pub struct SseStreamReader { + redis: fred::prelude::Pool, +} + +impl SseStreamReader { + pub fn new(redis: &fred::prelude::Pool) -> Self { + Self { + redis: redis.clone(), + } + } + + /// Retrieve the previous events from the given Redis stream. + /// Returns a tuple containing the previous events, the last event ID, and a boolean + /// indicating if the stream has already ended. + pub async fn get_prev_events(&self, key: &str) -> Result<(Vec, String, bool), LlmError> { + let (_, prev_events): (String, Vec<(String, HashMap)>) = self + .redis + .xread::>, _, _>(None, None, key, "0-0") + .await? + .and_then(|mut streams| streams.pop()) // should only be 1 stream since we're sending 1 key in the command + .ok_or(LlmError::StreamNotFound)?; + let (last_event_id, is_end) = prev_events + .last() + .map(|(id, data)| (id.to_owned(), data.get("type").is_some_and(|t| t == "end"))) + .unwrap_or_else(|| ("0-0".into(), false)); + let sse_events = prev_events + .into_iter() + .map(convert_redis_event_to_sse) + .collect::>(); + + Ok((sse_events, last_event_id, is_end)) + } + + /// Stream the events from the given Redis stream using a blocking `xread` command. + pub async fn stream_events(&self, key: &str, last_event_id: &str, tx: &mpsc::Sender) { + let mut last_event_id = last_event_id.to_owned(); + loop { + match self.get_next_event(key, &mut last_event_id, tx).await { + Ok((id, data, is_end)) => { + let event = convert_redis_event_to_sse((id, data)); + if let Err(_) = tx.send(event).await { + break; // client disconnected + } + if is_end { + break; // reached end of stream + } + } + Err(err) => { + let event = Event::data(format!("Error: {}", err)).event("error"); + tx.send(event).await.ok(); + break; + } + } + } + } + + /// Get the next event from the given Redis stream using a blocking `xread` command. + /// - Updates the last event ID + /// - Cancels waiting for the next event if the client disconnects + /// - Returns the event ID, data, and a `bool` indicating whether it's the ending event + async fn get_next_event( + &self, + key: &str, + last_event_id: &mut String, + tx: &mpsc::Sender, + ) -> Result<(String, HashMap, bool), LlmError> { + let (_, mut events): (String, Vec<(String, HashMap)>) = tokio::select! { + res = self.redis.xread::>, _, _>(Some(1), Some(XREAD_BLOCK_TIMEOUT), key, &*last_event_id) => { + match res?.as_mut().and_then(|streams| streams.pop()) { + Some(stream) => stream, + None => return Err(LlmError::StreamNotFound), + } + }, + _ = tx.closed() => return Err(LlmError::ClientDisconnected) + }; + match events.pop() { + Some((id, data)) => { + *last_event_id = id.clone(); + let is_end = data.get("type").is_some_and(|t| t == "end"); + Ok((id, data, is_end)) + } + None => Err(LlmError::NoStreamEvent), + } + } +} + +/// Convert a Redis stream event into an SSE event. Expects the event hash map to contain +/// a "type" and "data" field (e.g. serialized using the appropriate serde tag and content). +fn convert_redis_event_to_sse((id, event): (String, HashMap)) -> Event { + let mut r#type: Option = None; + let mut data: Option = None; + for (key, value) in event { + match key.as_str() { + "type" => r#type = Some(value), + "data" => data = Some(value), + _ => {} + } + } + + Event::data(data.unwrap_or_default()) + .event(r#type.unwrap_or_else(|| "unknown".into())) + .id(id) +} diff --git a/server/src/utils.rs b/server/src/utils.rs index 46b6f3c..32b9c14 100644 --- a/server/src/utils.rs +++ b/server/src/utils.rs @@ -2,7 +2,6 @@ mod encryption; mod full_text_search; mod generate_title; mod json_logging; -mod llm_stream; mod sender_with_logging; mod stored_stream; @@ -10,6 +9,5 @@ pub use encryption::*; pub use full_text_search::*; pub use generate_title::*; pub use json_logging::*; -pub use llm_stream::*; pub use sender_with_logging::*; pub use stored_stream::*; From 1eb96a5ed916752a4b820f45ed3e57e8319a91b3 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 22 Aug 2025 02:13:27 -0400 Subject: [PATCH 05/46] server: use Redis pool instead of client for consistency --- server/src/api/chat.rs | 116 +------------------------------ server/src/api/provider.rs | 3 +- server/src/api/session.rs | 33 +-------- server/src/provider.rs | 2 +- server/src/provider/anthropic.rs | 10 +-- server/src/provider/openai.rs | 6 +- server/src/provider_models.rs | 13 ++-- server/src/redis.rs | 41 ++--------- 8 files changed, 26 insertions(+), 198 deletions(-) diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index f60045b..b6fe03b 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -28,14 +28,13 @@ use crate::{ }, errors::ApiError, provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmError}, - redis::RedisClient, stream::{LlmStreamWriter, SseStreamReader}, tools::{get_llm_tools_from_input, SendChatToolInput}, - utils::{generate_title, Encryptor, StoredChatRsStream}, + utils::{generate_title, Encryptor}, }; pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { - openapi_get_routes_spec![settings: send_chat_stream, send_chat_stream_v2, connect_to_chat_stream] + openapi_get_routes_spec![settings: send_chat_stream_v2, connect_to_chat_stream] } #[derive(JsonSchema, serde::Deserialize)] @@ -50,115 +49,6 @@ pub struct SendChatInput<'a> { tools: Option, } -/// Send a chat message and stream the response -#[openapi(tag = "Chat")] -#[post("//v1", data = "")] -pub async fn send_chat_stream( - user_id: ChatRsUserId, - db_pool: &State, - mut db: DbConnection, - redis: RedisClient, - encryptor: &State, - http_client: &State, - session_id: Uuid, - mut input: Json>, -) -> Result + Send>>>, ApiError> { - // Check session exists and user is owner, get message history - let (session, mut current_messages) = ChatDbService::new(&mut db) - .get_session_with_messages(&user_id, &session_id) - .await?; - - // Build the chat provider - let (provider, api_key_secret) = ProviderDbService::new(&mut db) - .get_by_id(&user_id, input.provider_id) - .await?; - let api_key = api_key_secret - .map(|secret| encryptor.decrypt_string(&secret.ciphertext, &secret.nonce)) - .transpose()?; - let provider_api = build_llm_provider_api( - &provider.provider_type.as_str().try_into()?, - provider.base_url.as_deref(), - api_key.as_deref(), - &http_client, - &redis, - )?; - - // Get the user's chosen tools - let llm_tools = match input.tools.as_ref() { - Some(tool_input) => { - let mut tool_db_service = ToolDbService::new(&mut db); - Some(get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?) - } - None => None, - }; - - // Save user message and generate session title if needed - if let Some(user_message) = &input.message { - if current_messages.is_empty() && session.title == DEFAULT_SESSION_TITLE { - generate_title( - &user_id, - &session_id, - &user_message, - &provider_api, - &provider.default_model, - db_pool, - ); - } - let new_message = ChatDbService::new(&mut db) - .save_message(NewChatRsMessage { - content: user_message, - session_id: &session_id, - role: ChatRsMessageRole::User, - meta: ChatRsMessageMeta::default(), - }) - .await?; - current_messages.push(new_message); - } - - // Update session metadata - if let Some(tool_input) = input.tools.take() { - if session - .meta - .tool_config - .is_none_or(|config| config != tool_input) - { - ChatDbService::new(&mut db) - .update_session( - &user_id, - &session_id, - UpdateChatRsSession { - meta: Some(&ChatRsSessionMeta { - tool_config: Some(tool_input), - }), - ..Default::default() - }, - ) - .await?; - } - } - - // Get the provider's stream response and wrap it in our StoredChatRsStream - let stream = StoredChatRsStream::new( - provider_api - .chat_stream(current_messages, llm_tools, &input.provider_options) - .await?, - input.provider_id, - input.provider_options.clone(), - db_pool.inner().clone(), - redis.clone(), - Some(session_id), - ); - - // Start streaming - let event_stream = stream - .map(|result| match result { - Ok(message) => Event::data(format!(" {message}")).event("chat"), - Err(err) => Event::data(err.to_string()).event("error"), - }) - .boxed(); - Ok(EventStream::from(event_stream)) -} - /// # Start chat stream /// Send a chat message and start the streamed assistant response. After the response /// has started, use the `//stream` endpoint to connect to the SSE stream. @@ -197,7 +87,7 @@ pub async fn send_chat_stream_v2( provider.base_url.as_deref(), api_key.as_deref(), &http_client, - redis.next(), + redis, )?; // Get the user's chosen tools diff --git a/server/src/api/provider.rs b/server/src/api/provider.rs index e66fa02..dc2989c 100644 --- a/server/src/api/provider.rs +++ b/server/src/api/provider.rs @@ -18,7 +18,6 @@ use crate::{ errors::ApiError, provider::build_llm_provider_api, provider_models::LlmModel, - redis::RedisClient, utils::Encryptor, }; @@ -54,7 +53,7 @@ async fn get_all_providers( async fn list_models( user_id: ChatRsUserId, mut db: DbConnection, - redis: RedisClient, + redis: &State, encryptor: &State, http_client: &State, provider_id: i32, diff --git a/server/src/api/session.rs b/server/src/api/session.rs index a67573d..6be63d3 100644 --- a/server/src/api/session.rs +++ b/server/src/api/session.rs @@ -1,4 +1,3 @@ -use fred::prelude::KeysInterface; use rocket::{delete, get, patch, post, serde::json::Json, Route}; use rocket_okapi::{ okapi::openapi3::OpenApi, openapi, openapi_get_routes_spec, settings::OpenApiSettings, @@ -10,16 +9,12 @@ use uuid::Uuid; use crate::{ auth::ChatRsUserId, db::{ - models::{ - AssistantMeta, ChatRsMessage, ChatRsMessageMeta, ChatRsMessageRole, ChatRsSession, - NewChatRsSession, UpdateChatRsSession, - }, + models::{ChatRsMessage, ChatRsSession, NewChatRsSession, UpdateChatRsSession}, services::ChatDbService, DbConnection, }, errors::ApiError, - redis::RedisClient, - utils::{SessionSearchResult, CHAT_CACHE_KEY_PREFIX}, + utils::SessionSearchResult, }; pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { @@ -84,34 +79,12 @@ struct GetSessionResponse { async fn get_session( user_id: ChatRsUserId, mut db: DbConnection, - redis: RedisClient, session_id: Uuid, ) -> Result, ApiError> { - let (session, mut messages) = ChatDbService::new(&mut db) + let (session, messages) = ChatDbService::new(&mut db) .get_session_with_messages(&user_id, &session_id) .await?; - // Check for a cached response if the session is interrupted - let cached_response: Option = redis - .get(format!("{}{}", CHAT_CACHE_KEY_PREFIX, &session_id)) - .await?; - if let Some(interrupted_response) = cached_response { - messages.push(ChatRsMessage { - id: Uuid::new_v4(), - session_id, - role: ChatRsMessageRole::Assistant, - content: interrupted_response, - created_at: chrono::Utc::now(), - meta: ChatRsMessageMeta { - assistant: Some(AssistantMeta { - partial: Some(true), - ..Default::default() - }), - ..Default::default() - }, - }); - } - Ok(Json(GetSessionResponse { session, messages })) } diff --git a/server/src/provider.rs b/server/src/provider.rs index dccdec9..ba9bc30 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -130,7 +130,7 @@ pub fn build_llm_provider_api( base_url: Option<&str>, api_key: Option<&str>, http_client: &reqwest::Client, - redis: &fred::prelude::Client, + redis: &fred::prelude::Pool, ) -> Result, LlmError> { match provider_type { ChatRsProviderType::Openai => Ok(Box::new(OpenAIProvider::new( diff --git a/server/src/provider/anthropic.rs b/server/src/provider/anthropic.rs index 078acec..9239fe7 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/anthropic.rs @@ -21,16 +21,12 @@ const API_VERSION: &str = "2023-06-01"; #[derive(Debug, Clone)] pub struct AnthropicProvider { client: reqwest::Client, - redis: fred::prelude::Client, + redis: fred::prelude::Pool, api_key: String, } impl AnthropicProvider { - pub fn new( - http_client: &reqwest::Client, - redis: &fred::prelude::Client, - api_key: &str, - ) -> Self { + pub fn new(http_client: &reqwest::Client, redis: &fred::prelude::Pool, api_key: &str) -> Self { Self { client: http_client.clone(), redis: redis.clone(), @@ -357,7 +353,7 @@ impl LlmApiProvider for AnthropicProvider { } async fn list_models(&self) -> Result, LlmError> { - let models_service = ModelsDevService::new(self.redis.clone(), self.client.clone()); + let models_service = ModelsDevService::new(&self.redis, &self.client); let models = models_service .list_models(ModelsDevServiceProvider::Anthropic) .await?; diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index b6c0a68..d865554 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -19,7 +19,7 @@ const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1"; #[derive(Debug, Clone)] pub struct OpenAIProvider { client: reqwest::Client, - redis: fred::prelude::Client, + redis: fred::prelude::Pool, api_key: String, base_url: String, } @@ -27,7 +27,7 @@ pub struct OpenAIProvider { impl OpenAIProvider { pub fn new( http_client: &reqwest::Client, - redis: &fred::prelude::Client, + redis: &fred::prelude::Pool, api_key: &str, base_url: Option<&str>, ) -> Self { @@ -291,7 +291,7 @@ impl LlmApiProvider for OpenAIProvider { } async fn list_models(&self) -> Result, LlmError> { - let models_service = ModelsDevService::new(self.redis.clone(), self.client.clone()); + let models_service = ModelsDevService::new(&self.redis, &self.client); let models = models_service .list_models({ match self.base_url.as_str() { diff --git a/server/src/provider_models.rs b/server/src/provider_models.rs index 58e41f1..416c259 100644 --- a/server/src/provider_models.rs +++ b/server/src/provider_models.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use enum_iterator::{all, Sequence}; -use fred::prelude::*; +use fred::prelude::{HashesInterface, KeysInterface}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -56,13 +56,16 @@ pub enum ModalityType { /// Service to fetch and cache LLM model list from https://models.dev pub struct ModelsDevService { - redis: Client, + redis: fred::prelude::Pool, http_client: reqwest::Client, } impl ModelsDevService { - pub fn new(redis: Client, http_client: reqwest::Client) -> Self { - Self { redis, http_client } + pub fn new(redis: &fred::prelude::Pool, http_client: &reqwest::Client) -> Self { + Self { + redis: redis.clone(), + http_client: http_client.clone(), + } } pub async fn list_models( @@ -108,7 +111,7 @@ impl ModelsDevService { cache.insert(provider_str.to_owned(), parsed_models_str); } - let pipeline = self.redis.pipeline(); + let pipeline = self.redis.next().pipeline(); let _: () = pipeline.hset(CACHE_KEY, cache).await?; let _: () = pipeline.expire(CACHE_KEY, CACHE_TTL, None).await?; let _: () = pipeline.all().await?; diff --git a/server/src/redis.rs b/server/src/redis.rs index ecd1c07..c79750c 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -1,44 +1,11 @@ -use std::{ops::Deref, time::Duration}; +use std::time::Duration; -use fred::prelude::{Builder, Client, ClientLike, Config, Pool, TcpConfig}; -use rocket::{ - fairing::AdHoc, - http::Status, - request::{FromRequest, Outcome}, -}; -use rocket_okapi::OpenApiFromRequest; +use fred::prelude::{Builder, ClientLike, Config, Pool, TcpConfig}; +use rocket::fairing::AdHoc; use crate::config::get_app_config; -/// Redis connection, available as a request guard. When used as a request parameter, -/// it will retrieve a connection from the managed Redis pool. -#[derive(OpenApiFromRequest)] -pub struct RedisClient(Client); -impl Deref for RedisClient { - type Target = Client; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// Retrieve a connection from the managed Redis pool. Responds with an -/// internal server error if Redis not initialized. -#[rocket::async_trait] -impl<'r> FromRequest<'r> for RedisClient { - type Error = String; - - async fn from_request(req: &'r rocket::Request<'_>) -> Outcome { - let Some(pool) = req.rocket().state::() else { - return Outcome::Error(( - Status::InternalServerError, - "Redis not initialized".to_owned(), - )); - }; - Outcome::Success(RedisClient(pool.next().clone())) - } -} - -/// Fairing that sets up and initializes the Redis database +/// Fairing that sets up and initializes the Redis connection pool. pub fn setup_redis() -> AdHoc { AdHoc::on_ignite("Redis", |rocket| async { rocket From b85b9a765deb1592a3ce940f0bd77d92a06be728 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 22 Aug 2025 16:57:41 -0400 Subject: [PATCH 06/46] server: can list ongoing chat streams and cancel streams --- server/src/api/chat.rs | 88 ++++++++---- server/src/db/models/chat.rs | 5 + server/src/provider.rs | 2 +- server/src/redis.rs | 8 +- server/src/stream.rs | 12 ++ server/src/stream/llm_writer.rs | 107 +++++++++------ server/src/stream/reader.rs | 41 +++++- server/src/utils.rs | 2 - server/src/utils/stored_stream.rs | 221 ------------------------------ 9 files changed, 190 insertions(+), 296 deletions(-) delete mode 100644 server/src/utils/stored_stream.rs diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index b6fe03b..6596913 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -1,6 +1,5 @@ use std::{borrow::Cow, pin::Pin}; -use fred::prelude::KeysInterface; use rocket::{ futures::{stream, Stream, StreamExt}, get, post, @@ -34,7 +33,30 @@ use crate::{ }; pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { - openapi_get_routes_spec![settings: send_chat_stream_v2, connect_to_chat_stream] + openapi_get_routes_spec![ + settings: get_chat_streams, + send_chat_stream, + connect_to_chat_stream, + cancel_chat_stream, + ] +} + +#[derive(Debug, JsonSchema, serde::Serialize)] +pub struct GetChatStreamsResponse { + streams: Vec, +} + +/// # Get chat streams +/// Get the ongoing chat response streams +#[openapi(tag = "Chat")] +#[get("/streams")] +pub async fn get_chat_streams( + user_id: ChatRsUserId, + redis: &State, +) -> Result, ApiError> { + let stream_reader = SseStreamReader::new(&redis); + let keys = stream_reader.get_chat_streams(&user_id).await?; + Ok(Json(GetChatStreamsResponse { streams: keys })) } #[derive(JsonSchema, serde::Deserialize)] @@ -54,7 +76,7 @@ pub struct SendChatInput<'a> { /// has started, use the `//stream` endpoint to connect to the SSE stream. #[openapi(tag = "Chat")] #[post("/", data = "")] -pub async fn send_chat_stream_v2( +pub async fn send_chat_stream( user_id: ChatRsUserId, db_pool: &State, mut db: DbConnection, @@ -64,13 +86,14 @@ pub async fn send_chat_stream_v2( session_id: Uuid, mut input: Json>, ) -> Result { + let mut stream_writer = LlmStreamWriter::new(&redis, &user_id, &session_id); + // Check that we aren't already streaming a response for this session - let stream_key = format!("user:{}:chat:{}", user_id.0, session_id); - if redis.exists(&stream_key).await? { + if stream_writer.exists().await? { return Err(LlmError::AlreadyStreaming)?; } - // Check session exists and user is owner, get message history + // Get session and message history let (session, mut current_messages) = ChatDbService::new(&mut db) .get_session_with_messages(&user_id, &session_id) .await?; @@ -128,17 +151,13 @@ pub async fn send_chat_stream_v2( .tool_config .is_none_or(|config| config != tool_input) { + let meta = ChatRsSessionMeta::with_tool_config(Some(tool_input)); + let data = UpdateChatRsSession { + meta: Some(&meta), + ..Default::default() + }; ChatDbService::new(&mut db) - .update_session( - &user_id, - &session_id, - UpdateChatRsSession { - meta: Some(&ChatRsSessionMeta { - tool_config: Some(tool_input), - }), - ..Default::default() - }, - ) + .update_session(&user_id, &session_id, data) .await?; } } @@ -150,11 +169,11 @@ pub async fn send_chat_stream_v2( let provider_id = input.provider_id; let provider_options = input.provider_options.clone(); - // Spawn a task to stream the response to Redis and save it to the database on completion - let stream_writer = LlmStreamWriter::new(&redis); + // Create the Redis stream, then spawn a task to stream the response + // and save it to the database on completion + stream_writer.start().await?; tokio::spawn(async move { - let (content, tool_calls, usage, _) = - stream_writer.process_stream(&stream_key, stream).await; + let (content, tool_calls, usage, _) = stream_writer.process(stream).await; if let Err(e) = ChatDbService::new(&mut db) .save_message(NewChatRsMessage { session_id: &session_id, @@ -175,8 +194,8 @@ pub async fn send_chat_stream_v2( { rocket::warn!("Failed to save assistant response: {}", e); } - - // TODO delete stream in Redis + stream_writer.finish().await.ok(); + drop(stream_writer); }); Ok("Stream started".into()) @@ -191,21 +210,21 @@ pub async fn connect_to_chat_stream( redis: &State, session_id: Uuid, ) -> Result + Send>>>, ApiError> { - let stream_key = format!("user:{}:chat:{}", user_id.0, session_id); let stream_reader = SseStreamReader::new(&redis); // Get all previous events from the Redis stream, and return them if we're already at the end of the stream - let (prev_events, last_event_id, is_end) = stream_reader.get_prev_events(&stream_key).await?; + let (prev_events, last_event_id, is_end) = + stream_reader.get_prev_events(&user_id, &session_id).await?; let prev_events_stream = stream::iter(prev_events); if is_end { return Ok(EventStream::from(prev_events_stream.boxed())); } - // Spawn a task to receive new events from Redis and add them to the channel + // Spawn a task to receive new events from Redis and add them to this channel let (tx, rx) = tokio::sync::mpsc::channel::(50); tokio::spawn(async move { stream_reader - .stream_events(&stream_key, &last_event_id, &tx) + .stream(&user_id, &session_id, &last_event_id, &tx) .await; drop(tx); }); @@ -214,3 +233,20 @@ pub async fn connect_to_chat_stream( let stream = prev_events_stream.chain(ReceiverStream::new(rx)).boxed(); Ok(EventStream::from(stream)) } + +/// # Cancel chat stream +/// Cancel an ongoing chat stream +#[openapi(tag = "Chat")] +#[post("//cancel")] +pub async fn cancel_chat_stream( + user_id: ChatRsUserId, + redis: &State, + session_id: Uuid, +) -> Result<(), ApiError> { + let stream_writer = LlmStreamWriter::new(&redis, &user_id, &session_id); + if !stream_writer.exists().await? { + return Err(LlmError::StreamNotFound)?; + } + stream_writer.cancel().await?; + Ok(()) +} diff --git a/server/src/db/models/chat.rs b/server/src/db/models/chat.rs index a32ccc5..6abc024 100644 --- a/server/src/db/models/chat.rs +++ b/server/src/db/models/chat.rs @@ -33,6 +33,11 @@ pub struct ChatRsSessionMeta { #[serde(skip_serializing_if = "Option::is_none")] pub tool_config: Option, } +impl ChatRsSessionMeta { + pub fn with_tool_config(tool_config: Option) -> Self { + Self { tool_config } + } +} #[derive(Insertable)] #[diesel(table_name = super::schema::chat_sessions)] diff --git a/server/src/provider.rs b/server/src/provider.rs index ba9bc30..d404b31 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -39,7 +39,7 @@ pub enum LlmError { UnsupportedProvider, #[error("Already streaming a response for this session")] AlreadyStreaming, - #[error("No ongoing stream for this session")] + #[error("No stream found, or the stream was cancelled")] StreamNotFound, #[error("Missing event in stream")] NoStreamEvent, diff --git a/server/src/redis.rs b/server/src/redis.rs index c79750c..656f497 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use fred::prelude::{Builder, ClientLike, Config, Pool, TcpConfig}; +use fred::prelude::{Builder, ClientLike, Config, Pool, ReconnectPolicy, TcpConfig}; use rocket::fairing::AdHoc; use crate::config::get_app_config; @@ -18,11 +18,17 @@ pub fn setup_redis() -> AdHoc { let pool = Builder::from_config(config) .with_connection_config(|config| { config.connection_timeout = Duration::from_secs(4); + config.internal_command_timeout = Duration::from_secs(6); + config.max_command_attempts = 2; config.tcp = TcpConfig { nodelay: Some(true), ..Default::default() }; }) + .set_policy(ReconnectPolicy::new_linear(5, 4000, 1000)) + .with_performance_config(|config| { + config.default_command_timeout = Duration::from_secs(6); + }) .build_pool(app_config.redis_pool.unwrap_or(4)) .expect("Failed to build Redis pool"); pool.init().await.expect("Failed to connect to Redis"); diff --git a/server/src/stream.rs b/server/src/stream.rs index df4e981..6881fc9 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -3,3 +3,15 @@ mod reader; pub use llm_writer::*; pub use reader::*; + +use uuid::Uuid; + +/// Get the key prefix for the user's chat streams in Redis +fn get_chat_stream_prefix(user_id: &Uuid) -> String { + format!("user:{}:chat", user_id) +} + +/// Get the key of the chat stream in Redis for the given user and session ID +fn get_chat_stream_key(user_id: &Uuid, session_id: &Uuid) -> String { + format!("{}:{}", get_chat_stream_prefix(user_id), session_id) +} diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 1c18be2..b4ff2c4 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -6,10 +6,12 @@ use std::{ use fred::prelude::{KeysInterface, StreamsInterface}; use rocket::futures::StreamExt; use serde::Serialize; +use uuid::Uuid; use crate::{ db::models::ChatRsToolCall, provider::{LlmApiStream, LlmError, LlmUsage}, + stream::get_chat_stream_key, }; /// Interval at which chunks are flushed to Redis. @@ -19,11 +21,12 @@ const MAX_CHUNK_SIZE: usize = 1000; /// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) const STREAM_EXPIRE: i64 = 30; -/// Utility for processing an incoming LLM response stream and intermittently -/// writing chunks to a Redis stream. +/// Utility for processing an incoming LLM response stream and writing to a Redis stream. #[derive(Debug)] pub struct LlmStreamWriter { redis: fred::prelude::Pool, + /// The key of the Redis stream. + key: String, /// The current chunk of data being processed. current_chunk: ChunkState, /// Accumulated text response from the assistant. @@ -52,6 +55,7 @@ enum RedisStreamChunk { Text(String), ToolCall(String), Error(String), + Cancel, End, } impl From for HashMap { @@ -63,9 +67,10 @@ impl From for HashMap { } impl LlmStreamWriter { - pub fn new(redis: &fred::prelude::Pool) -> Self { + pub fn new(redis: &fred::prelude::Pool, user_id: &Uuid, session_id: &Uuid) -> Self { LlmStreamWriter { redis: redis.clone(), + key: get_chat_stream_key(user_id, session_id), current_chunk: ChunkState::default(), complete_text: None, tool_calls: None, @@ -74,11 +79,24 @@ impl LlmStreamWriter { } } + /// Check if the Redis stream already exists. + pub async fn exists(&self) -> Result { + self.redis.exists(&self.key).await + } + + /// Create the Redis stream and write a `start` entry. + pub async fn start(&self) -> Result<(), fred::prelude::Error> { + let entry: HashMap = RedisStreamChunk::Start.into(); + let pipeline = self.redis.next().pipeline(); + let _: () = pipeline.xadd(&self.key, false, None, "*", entry).await?; + let _: () = pipeline.expire(&self.key, STREAM_EXPIRE, None).await?; + pipeline.all().await + } + /// Process the incoming stream from the LLM provider, intermittently /// flushing chunks to a Redis stream, and return the final accumulated response. - pub async fn process_stream( - mut self, - stream_key: &str, + pub async fn process( + &mut self, mut stream: LlmApiStream, ) -> ( Option, @@ -86,10 +104,6 @@ impl LlmStreamWriter { Option, Option>, ) { - if let Err(e) = self.notify_start_of_redis_stream(&stream_key).await { - self.errors.get_or_insert_default().push(LlmError::Redis(e)); - }; - let mut last_flush_time = Instant::now(); while let Some(chunk) = stream.next().await { match chunk { @@ -110,16 +124,41 @@ impl LlmStreamWriter { } if self.should_flush(&last_flush_time) { - self.flush_and_reset(&stream_key).await; + if let Err(err) = self.flush_chunk().await { + if matches!(err, LlmError::StreamNotFound) { + self.errors.get_or_insert_default().push(err); + break; // stream was deleted/cancelled + } + self.process_error(err); + } last_flush_time = Instant::now(); } } - if let Err(e) = self.notify_end_of_redis_stream(&stream_key).await { - self.errors.get_or_insert_default().push(LlmError::Redis(e)); - }; + let complete_text = self.complete_text.take(); + let tool_calls = self.tool_calls.take(); + let usage = self.usage.take(); + let errors = self.errors.take(); + (complete_text, tool_calls, usage, errors) + } - (self.complete_text, self.tool_calls, self.usage, self.errors) + /// Cancel stream by adding a `cancel` event to the stream and then deleting it from Redis. + pub async fn cancel(&self) -> Result<(), fred::prelude::Error> { + let entry: HashMap = RedisStreamChunk::Cancel.into(); + let pipeline = self.redis.next().pipeline(); + let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; + let _: () = pipeline.del(&self.key).await?; + pipeline.all().await + } + + /// Add an `end` event to notify clients that the stream has ended, and then + /// delete the stream from Redis. + pub async fn finish(&self) -> Result<(), fred::prelude::Error> { + let entry: HashMap = RedisStreamChunk::End.into(); + let pipeline = self.redis.next().pipeline(); + let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; + let _: () = pipeline.del(&self.key).await?; + pipeline.all().await } fn process_text(&mut self, text: &str) { @@ -168,7 +207,7 @@ impl LlmStreamWriter { return false; } - async fn flush_and_reset(&mut self, stream_key: &str) { + async fn flush_chunk(&mut self) -> Result<(), LlmError> { let chunk_state = std::mem::take(&mut self.current_chunk); let mut chunks: Vec = Vec::with_capacity(2); @@ -185,35 +224,25 @@ impl LlmStreamWriter { } let entries = chunks.into_iter().map(|chunk| chunk.into()).collect(); - let _ = self.add_to_redis_stream(stream_key, entries).await; - } - - async fn notify_start_of_redis_stream( - &mut self, - stream_key: &str, - ) -> Result<(), fred::prelude::Error> { - let entries = vec![RedisStreamChunk::Start.into()]; - self.add_to_redis_stream(stream_key, entries).await - } - - async fn notify_end_of_redis_stream( - &mut self, - stream_key: &str, - ) -> Result<(), fred::prelude::Error> { - let entries = vec![RedisStreamChunk::End.into()]; - self.add_to_redis_stream(stream_key, entries).await + self.add_to_redis_stream(entries).await } async fn add_to_redis_stream( - &mut self, - stream_key: &str, + &self, entries: Vec>, - ) -> Result<(), fred::prelude::Error> { + ) -> Result<(), LlmError> { let pipeline = self.redis.next().pipeline(); for entry in entries { - let _: () = pipeline.xadd(stream_key, false, None, "*", entry).await?; + let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; + } + let _: () = pipeline.expire(&self.key, STREAM_EXPIRE, None).await?; + let res: Vec = pipeline.all().await?; + + // Check for `nil` responses indicating the stream has been deleted/cancelled + if res.iter().any(|r| matches!(r, fred::prelude::Value::Null)) { + Err(LlmError::StreamNotFound) + } else { + Ok(()) } - let _: () = pipeline.expire(stream_key, STREAM_EXPIRE, None).await?; - pipeline.all().await } } diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index 4bc2830..35aeeaf 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -1,10 +1,17 @@ use std::collections::HashMap; -use fred::prelude::StreamsInterface; +use fred::{ + prelude::{KeysInterface, StreamsInterface}, + types::scan::ScanType, +}; use rocket::response::stream::Event; use tokio::sync::mpsc; +use uuid::Uuid; -use crate::provider::LlmError; +use crate::{ + provider::LlmError, + stream::{get_chat_stream_key, get_chat_stream_prefix}, +}; /// Timeout for the blocking `xread` command. const XREAD_BLOCK_TIMEOUT: u64 = 10_000; // 10 seconds @@ -21,13 +28,28 @@ impl SseStreamReader { } } + /// Get the ongoing chat streams for a user. + pub async fn get_chat_streams(&self, user_id: &Uuid) -> Result, LlmError> { + let pattern = format!("{}:*", get_chat_stream_prefix(user_id)); + let keys = self + .redis + .scan_page("0", &pattern, Some(20), Some(ScanType::Stream)) + .await?; + Ok(keys) + } + /// Retrieve the previous events from the given Redis stream. /// Returns a tuple containing the previous events, the last event ID, and a boolean /// indicating if the stream has already ended. - pub async fn get_prev_events(&self, key: &str) -> Result<(Vec, String, bool), LlmError> { + pub async fn get_prev_events( + &self, + user_id: &Uuid, + session_id: &Uuid, + ) -> Result<(Vec, String, bool), LlmError> { + let key = get_chat_stream_key(user_id, session_id); let (_, prev_events): (String, Vec<(String, HashMap)>) = self .redis - .xread::>, _, _>(None, None, key, "0-0") + .xread::>, _, _>(None, None, &key, "0-0") .await? .and_then(|mut streams| streams.pop()) // should only be 1 stream since we're sending 1 key in the command .ok_or(LlmError::StreamNotFound)?; @@ -44,10 +66,17 @@ impl SseStreamReader { } /// Stream the events from the given Redis stream using a blocking `xread` command. - pub async fn stream_events(&self, key: &str, last_event_id: &str, tx: &mpsc::Sender) { + pub async fn stream( + &self, + user_id: &Uuid, + session_id: &Uuid, + last_event_id: &str, + tx: &mpsc::Sender, + ) { + let key = get_chat_stream_key(user_id, session_id); let mut last_event_id = last_event_id.to_owned(); loop { - match self.get_next_event(key, &mut last_event_id, tx).await { + match self.get_next_event(&key, &mut last_event_id, tx).await { Ok((id, data, is_end)) => { let event = convert_redis_event_to_sse((id, data)); if let Err(_) = tx.send(event).await { diff --git a/server/src/utils.rs b/server/src/utils.rs index 32b9c14..bd85450 100644 --- a/server/src/utils.rs +++ b/server/src/utils.rs @@ -3,11 +3,9 @@ mod full_text_search; mod generate_title; mod json_logging; mod sender_with_logging; -mod stored_stream; pub use encryption::*; pub use full_text_search::*; pub use generate_title::*; pub use json_logging::*; pub use sender_with_logging::*; -pub use stored_stream::*; diff --git a/server/src/utils/stored_stream.rs b/server/src/utils/stored_stream.rs deleted file mode 100644 index 35f54a1..0000000 --- a/server/src/utils/stored_stream.rs +++ /dev/null @@ -1,221 +0,0 @@ -use std::{ - pin::Pin, - task::{Context, Poll}, - time::{Duration, Instant}, -}; - -use fred::{ - prelude::{Client, KeysInterface}, - types::Expiration, -}; -use rocket::futures::Stream; -use uuid::Uuid; - -use crate::{ - db::{ - models::{ - AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsToolCall, NewChatRsMessage, - }, - services::ChatDbService, - DbConnection, DbPool, - }, - provider::{LlmApiProviderSharedOptions, LlmUsage}, -}; - -/// A wrapper around the chat assistant stream that intermittently caches output in Redis, and -/// saves the assistant's response to the database at the end of the stream. -pub struct StoredChatRsStream< - S: Stream>, -> { - inner: Pin>, - provider_id: i32, - provider_options: Option, - redis_client: Client, - db_pool: DbPool, - session_id: Uuid, - buffer: Vec, - tool_calls: Option>, - input_tokens: u32, - output_tokens: u32, - cost: Option, - last_cache_time: Instant, -} - -/// The prefix for the cache key used for streaming chat responses. -pub const CHAT_CACHE_KEY_PREFIX: &str = "chat_session:"; -const CHAT_CACHE_INTERVAL: Duration = Duration::from_secs(1); // cache the response every second - -impl StoredChatRsStream -where - S: Stream>, -{ - pub fn new( - stream: S, - provider_id: i32, - provider_options: LlmApiProviderSharedOptions, - db_pool: DbPool, - redis_client: Client, - session_id: Option, - ) -> Self { - Self { - inner: Box::pin(stream), - provider_id, - provider_options: Some(provider_options), - db_pool, - redis_client, - session_id: session_id.unwrap_or_else(|| Uuid::new_v4()), - buffer: Vec::new(), - tool_calls: None, - input_tokens: 0, - output_tokens: 0, - cost: None, - last_cache_time: Instant::now(), - } - } - - pub fn session_id(&self) -> &Uuid { - &self.session_id - } - - fn save_response(&mut self, interrupted: Option) { - let redis_client = self.redis_client.clone(); - let pool = self.db_pool.clone(); - let session_id = self.session_id.clone(); - let provider_id = self.provider_id.clone(); - let provider_options = self.provider_options.take(); - let content = self.buffer.join(""); - let tool_calls = self.tool_calls.take(); - let usage = Some(LlmUsage { - input_tokens: Some(self.input_tokens), - output_tokens: Some(self.output_tokens), - cost: self.cost, - }); - self.buffer.clear(); - - tokio::spawn(async move { - let Ok(db) = pool.get().await else { - rocket::error!("Couldn't get connection while saving chat response"); - return; - }; - if let Err(e) = ChatDbService::new(&mut DbConnection(db)) - .save_message(NewChatRsMessage { - role: ChatRsMessageRole::Assistant, - content: &content, - session_id: &session_id, - meta: ChatRsMessageMeta { - assistant: Some(AssistantMeta { - provider_id, - provider_options, - partial: interrupted, - usage, - tool_calls, - }), - ..Default::default() - }, - }) - .await - { - rocket::error!("Failed saving chat response, session {}: {}", session_id, e); - } else { - rocket::info!("Saved chat response, session {}", session_id); - } - - let key = format!("{}{}", CHAT_CACHE_KEY_PREFIX, session_id); - let _ = redis_client.del::<(), _>(&key).await; - }); - } - - fn should_cache(&self) -> bool { - self.last_cache_time.elapsed() >= CHAT_CACHE_INTERVAL - } -} - -impl Stream for StoredChatRsStream -where - S: Stream>, -{ - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.inner.as_mut().poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { - // Add text to buffer - if let Some(text) = &chunk.text { - self.buffer.push(text.clone()); - } - - // Record tool calls - if let Some(tool_calls) = chunk.tool_calls { - self.tool_calls.get_or_insert_default().extend(tool_calls); - } - - // Record usage - if let Some(usage) = chunk.usage { - if let Some(input_tokens) = usage.input_tokens { - self.input_tokens = input_tokens; - } - if let Some(output_tokens) = usage.output_tokens { - self.output_tokens = output_tokens; - } - if let Some(cost) = usage.cost { - self.cost = Some(cost); - } - } - - // Check if we should cache - if self.should_cache() { - let redis_client = self.redis_client.clone(); - let session_id = self.session_id.clone(); - let content = self.buffer.join(""); - - // Spawn async task to cache - tokio::spawn(async move { - let key = format!("{}{}", CHAT_CACHE_KEY_PREFIX, session_id); - rocket::debug!("Caching chat session {}", session_id); - if let Err(e) = redis_client - .set::<(), _, _>( - &key, - &content, - Some(Expiration::EX(3600)), - None, - false, - ) - .await - { - rocket::error!("Redis cache error: {}", e); - } - }); - - self.last_cache_time = Instant::now(); - } - - if let Some(text) = chunk.text { - Poll::Ready(Some(Ok(text))) - } else { - self.poll_next(cx) - } - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - // Stream ended, flush final buffer - if !self.buffer.is_empty() || self.tool_calls.is_some() { - self.save_response(None); - } - Poll::Ready(None) - } - Poll::Pending => Poll::Pending, - } - } -} - -impl Drop for StoredChatRsStream -where - S: Stream>, -{ - /// Stream was interrupted. Save response and mark as interrupted - fn drop(&mut self) { - if !self.buffer.is_empty() || self.tool_calls.is_some() { - self.save_response(Some(true)); - } - } -} From 60d2f2d897f2d9d515620f150245e6af8ca9d7a0 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 22 Aug 2025 19:59:03 -0400 Subject: [PATCH 07/46] server: move db saving logic to stream utility --- server/src/api/chat.rs | 59 +++++---------- server/src/db/models/chat.rs | 13 +++- server/src/provider.rs | 2 + server/src/stream/llm_writer.rs | 122 +++++++++++++++++++++++--------- 4 files changed, 121 insertions(+), 75 deletions(-) diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index 6596913..b012a7b 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -19,8 +19,8 @@ use crate::{ auth::ChatRsUserId, db::{ models::{ - AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsSessionMeta, - NewChatRsMessage, UpdateChatRsSession, + ChatRsMessageMeta, ChatRsMessageRole, ChatRsSessionMeta, NewChatRsMessage, + UpdateChatRsSession, }, services::{ChatDbService, ProviderDbService, ToolDbService}, DbConnection, DbPool, @@ -65,8 +65,8 @@ pub struct SendChatInput<'a> { message: Option>, /// The ID of the provider to chat with provider_id: i32, - /// Provider options - provider_options: LlmApiProviderSharedOptions, + /// Configuration for the provider + options: LlmApiProviderSharedOptions, /// Configuration of tools available to the assistant tools: Option, } @@ -94,7 +94,7 @@ pub async fn send_chat_stream( } // Get session and message history - let (session, mut current_messages) = ChatDbService::new(&mut db) + let (session, mut messages) = ChatDbService::new(&mut db) .get_session_with_messages(&user_id, &session_id) .await?; @@ -114,16 +114,15 @@ pub async fn send_chat_stream( )?; // Get the user's chosen tools - let mut llm_tools = None; + let mut tools = None; if let Some(tool_input) = input.tools.as_ref() { let mut tool_db_service = ToolDbService::new(&mut db); - let tools = get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?; - llm_tools = Some(tools); + tools = Some(get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?); } // Generate session title if needed, and save user message to database if let Some(user_message) = &input.message { - if current_messages.is_empty() && session.title == DEFAULT_SESSION_TITLE { + if messages.is_empty() && session.title == DEFAULT_SESSION_TITLE { generate_title( &user_id, &session_id, @@ -141,7 +140,7 @@ pub async fn send_chat_stream( meta: ChatRsMessageMeta::default(), }) .await?; - current_messages.push(new_message); + messages.push(new_message); } // Update session metadata if needed @@ -151,7 +150,7 @@ pub async fn send_chat_stream( .tool_config .is_none_or(|config| config != tool_input) { - let meta = ChatRsSessionMeta::with_tool_config(Some(tool_input)); + let meta = ChatRsSessionMeta::new(Some(tool_input)); let data = UpdateChatRsSession { meta: Some(&meta), ..Default::default() @@ -164,41 +163,21 @@ pub async fn send_chat_stream( // Get the provider's stream response let stream = provider_api - .chat_stream(current_messages, llm_tools, &input.provider_options) + .chat_stream(messages, tools, &input.options) .await?; let provider_id = input.provider_id; - let provider_options = input.provider_options.clone(); + let provider_options = input.options.clone(); - // Create the Redis stream, then spawn a task to stream the response - // and save it to the database on completion - stream_writer.start().await?; + // Create the Redis stream, then spawn a task to stream and save the response + stream_writer.create().await?; tokio::spawn(async move { - let (content, tool_calls, usage, _) = stream_writer.process(stream).await; - if let Err(e) = ChatDbService::new(&mut db) - .save_message(NewChatRsMessage { - session_id: &session_id, - role: ChatRsMessageRole::Assistant, - content: &content.unwrap_or_default(), - meta: ChatRsMessageMeta { - assistant: Some(AssistantMeta { - provider_id, - provider_options: Some(provider_options), - tool_calls, - usage, - ..Default::default() - }), - ..Default::default() - }, - }) - .await - { - rocket::warn!("Failed to save assistant response: {}", e); - } - stream_writer.finish().await.ok(); - drop(stream_writer); + let mut chat_db_service = ChatDbService::new(&mut db); + stream_writer + .process(stream, &mut chat_db_service, provider_id, provider_options) + .await; }); - Ok("Stream started".into()) + Ok(format!("Stream started at /api/chat/{}/stream", session_id)) } /// # Connect to chat stream diff --git a/server/src/db/models/chat.rs b/server/src/db/models/chat.rs index 6abc024..60080c9 100644 --- a/server/src/db/models/chat.rs +++ b/server/src/db/models/chat.rs @@ -34,7 +34,7 @@ pub struct ChatRsSessionMeta { pub tool_config: Option, } impl ChatRsSessionMeta { - pub fn with_tool_config(tool_config: Option) -> Self { + pub fn new(tool_config: Option) -> Self { Self { tool_config } } } @@ -84,6 +84,14 @@ pub struct ChatRsMessageMeta { #[serde(skip_serializing_if = "Option::is_none")] pub tool_call: Option, } +impl ChatRsMessageMeta { + pub fn new_assistant(assistant: AssistantMeta) -> Self { + Self { + assistant: Some(assistant), + tool_call: None, + } + } +} #[derive(Debug, Default, JsonSchema, Serialize, Deserialize)] pub struct AssistantMeta { @@ -98,6 +106,9 @@ pub struct AssistantMeta { /// Provider usage information #[serde(skip_serializing_if = "Option::is_none")] pub usage: Option, + /// Errors encountered during message generation + #[serde(skip_serializing_if = "Option::is_none")] + pub errors: Option>, /// Whether this is a partial and/or interrupted message #[serde(skip_serializing_if = "Option::is_none")] pub partial: Option, diff --git a/server/src/provider.rs b/server/src/provider.rs index d404b31..09ff351 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -51,6 +51,8 @@ pub enum LlmError { DecryptionError, #[error("Redis error: {0}")] Redis(#[from] fred::error::Error), + #[error("Failed to save message: {0}")] + Database(#[from] diesel::result::Error), } /// A streaming chunk of data from the LLM provider diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index b4ff2c4..3dd6106 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -9,8 +9,13 @@ use serde::Serialize; use uuid::Uuid; use crate::{ - db::models::ChatRsToolCall, - provider::{LlmApiStream, LlmError, LlmUsage}, + db::{ + models::{ + AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsToolCall, NewChatRsMessage, + }, + services::ChatDbService, + }, + provider::{LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmUsage}, stream::get_chat_stream_key, }; @@ -27,6 +32,8 @@ pub struct LlmStreamWriter { redis: fred::prelude::Pool, /// The key of the Redis stream. key: String, + /// The chat session ID associated with the stream. + session_id: Uuid, /// The current chunk of data being processed. current_chunk: ChunkState, /// Accumulated text response from the assistant. @@ -71,6 +78,7 @@ impl LlmStreamWriter { LlmStreamWriter { redis: redis.clone(), key: get_chat_stream_key(user_id, session_id), + session_id: session_id.to_owned(), current_chunk: ChunkState::default(), complete_text: None, tool_calls: None, @@ -85,7 +93,7 @@ impl LlmStreamWriter { } /// Create the Redis stream and write a `start` entry. - pub async fn start(&self) -> Result<(), fred::prelude::Error> { + pub async fn create(&self) -> Result<(), fred::prelude::Error> { let entry: HashMap = RedisStreamChunk::Start.into(); let pipeline = self.redis.next().pipeline(); let _: () = pipeline.xadd(&self.key, false, None, "*", entry).await?; @@ -93,18 +101,17 @@ impl LlmStreamWriter { pipeline.all().await } - /// Process the incoming stream from the LLM provider, intermittently - /// flushing chunks to a Redis stream, and return the final accumulated response. + /// Process the incoming stream from the LLM provider, intermittently flushing + /// chunks to a Redis stream, and saving the final accumulated response to the database. pub async fn process( &mut self, mut stream: LlmApiStream, - ) -> ( - Option, - Option>, - Option, - Option>, + db: &mut ChatDbService<'_>, + provider_id: i32, + provider_options: LlmApiProviderSharedOptions, ) { let mut last_flush_time = Instant::now(); + let mut cancelled = false; while let Some(chunk) = stream.next().await { match chunk { Ok(chunk) => { @@ -127,7 +134,8 @@ impl LlmStreamWriter { if let Err(err) = self.flush_chunk().await { if matches!(err, LlmError::StreamNotFound) { self.errors.get_or_insert_default().push(err); - break; // stream was deleted/cancelled + cancelled = true; + break; } self.process_error(err); } @@ -135,30 +143,19 @@ impl LlmStreamWriter { } } - let complete_text = self.complete_text.take(); - let tool_calls = self.tool_calls.take(); - let usage = self.usage.take(); - let errors = self.errors.take(); - (complete_text, tool_calls, usage, errors) - } - - /// Cancel stream by adding a `cancel` event to the stream and then deleting it from Redis. - pub async fn cancel(&self) -> Result<(), fred::prelude::Error> { - let entry: HashMap = RedisStreamChunk::Cancel.into(); - let pipeline = self.redis.next().pipeline(); - let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; - let _: () = pipeline.del(&self.key).await?; - pipeline.all().await - } + if let Err(e) = self + .save_to_db(db, provider_id, provider_options, cancelled) + .await + { + if !cancelled { + self.current_chunk.error = Some(e.to_string()); + self.flush_chunk().await.ok(); + } + } - /// Add an `end` event to notify clients that the stream has ended, and then - /// delete the stream from Redis. - pub async fn finish(&self) -> Result<(), fred::prelude::Error> { - let entry: HashMap = RedisStreamChunk::End.into(); - let pipeline = self.redis.next().pipeline(); - let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; - let _: () = pipeline.del(&self.key).await?; - pipeline.all().await + if !cancelled { + self.finish().await.ok(); + } } fn process_text(&mut self, text: &str) { @@ -227,6 +224,7 @@ impl LlmStreamWriter { self.add_to_redis_stream(entries).await } + /// Adds a new entry to the Redis stream. Returns an error if the stream has been deleted or cancelled. async fn add_to_redis_stream( &self, entries: Vec>, @@ -245,4 +243,60 @@ impl LlmStreamWriter { Ok(()) } } + + /// Add an `end` event to notify clients that the stream has ended, and then + /// delete the stream from Redis. + async fn finish(&self) -> Result<(), fred::prelude::Error> { + let entry: HashMap = RedisStreamChunk::End.into(); + let pipeline = self.redis.next().pipeline(); + let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; + let _: () = pipeline.del(&self.key).await?; + pipeline.all().await + } + + /// Cancel stream by adding a `cancel` event to the stream and then deleting it from Redis. + pub async fn cancel(&self) -> Result<(), fred::prelude::Error> { + let entry: HashMap = RedisStreamChunk::Cancel.into(); + let pipeline = self.redis.next().pipeline(); + let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; + let _: () = pipeline.del(&self.key).await?; + pipeline.all().await + } + + /// Saves the accumulated response to the database + async fn save_to_db( + &mut self, + db_service: &mut ChatDbService<'_>, + provider_id: i32, + provider_options: LlmApiProviderSharedOptions, + cancelled: bool, + ) -> Result<(), LlmError> { + let complete_text = self.complete_text.take(); + let tool_calls = self.tool_calls.take(); + let usage = self.usage.take(); + let errors = self.errors.take().map(|e| { + e.into_iter() + .map(|e| e.to_string()) + .collect::>() + }); + + let assistant_meta = AssistantMeta { + provider_id, + provider_options: Some(provider_options), + tool_calls, + usage, + errors, + partial: cancelled.then_some(true), + }; + db_service + .save_message(NewChatRsMessage { + session_id: &self.session_id, + role: ChatRsMessageRole::Assistant, + content: &complete_text.unwrap_or_default(), + meta: ChatRsMessageMeta::new_assistant(assistant_meta), + }) + .await?; + + Ok(()) + } } From 2c343ae3eba1f7ebae0138fd7dbaaa2c7d0aac56 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 22 Aug 2025 20:15:56 -0400 Subject: [PATCH 08/46] server: tweak redis connection settings --- server/src/auth/session.rs | 15 +++------------ server/src/redis.rs | 37 ++++++++++++++++++++++--------------- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/server/src/auth/session.rs b/server/src/auth/session.rs index 699ba6a..0e9a6c9 100644 --- a/server/src/auth/session.rs +++ b/server/src/auth/session.rs @@ -1,11 +1,11 @@ -use std::{ops::Deref, time::Duration}; +use std::ops::Deref; use chrono::Utc; use rocket::fairing::AdHoc; use rocket_flex_session::{storage::redis::RedisFredStorage, RocketFlexSession}; use uuid::Uuid; -use crate::config::get_app_config; +use crate::{config::get_app_config, redis::build_redis_pool}; const USER_ID_KEY: &str = "user_id"; const USER_ID_BYTES_KEY: &str = "user_id_bytes"; @@ -74,16 +74,7 @@ pub fn setup_session() -> AdHoc { let app_config = get_app_config(&rocket); let config = fred::prelude::Config::from_url(&app_config.redis_url) .expect("RS_CHAT_REDIS_URL should be valid Redis URL"); - let session_redis_pool = fred::prelude::Builder::from_config(config) - .with_connection_config(|config| { - config.connection_timeout = Duration::from_secs(4); - config.tcp = fred::prelude::TcpConfig { - nodelay: Some(true), - ..Default::default() - }; - }) - .build_pool(app_config.redis_pool.unwrap_or(2)) - .expect("Failed to build Redis session pool"); + let session_redis_pool = build_redis_pool(config, 2).expect("Failed to build Redis pool"); let session_fairing: RocketFlexSession = RocketFlexSession::builder() .with_options(|opt| { opt.cookie_name = "auth_rs_chat".to_string(); diff --git a/server/src/redis.rs b/server/src/redis.rs index 656f497..77ed90d 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -15,21 +15,7 @@ pub fn setup_redis() -> AdHoc { let app_config = get_app_config(&rocket); let config = Config::from_url(&app_config.redis_url) .expect("RS_CHAT_REDIS_URL should be valid Redis URL"); - let pool = Builder::from_config(config) - .with_connection_config(|config| { - config.connection_timeout = Duration::from_secs(4); - config.internal_command_timeout = Duration::from_secs(6); - config.max_command_attempts = 2; - config.tcp = TcpConfig { - nodelay: Some(true), - ..Default::default() - }; - }) - .set_policy(ReconnectPolicy::new_linear(5, 4000, 1000)) - .with_performance_config(|config| { - config.default_command_timeout = Duration::from_secs(6); - }) - .build_pool(app_config.redis_pool.unwrap_or(4)) + let pool = build_redis_pool(config, app_config.redis_pool.unwrap_or(4)) .expect("Failed to build Redis pool"); pool.init().await.expect("Failed to connect to Redis"); @@ -48,3 +34,24 @@ pub fn setup_redis() -> AdHoc { })) }) } + +pub fn build_redis_pool( + redis_config: Config, + pool_size: usize, +) -> Result { + Builder::from_config(redis_config) + .with_connection_config(|config| { + config.connection_timeout = Duration::from_secs(4); + config.internal_command_timeout = Duration::from_secs(6); + config.max_command_attempts = 2; + config.tcp = TcpConfig { + nodelay: Some(true), + ..Default::default() + }; + }) + .set_policy(ReconnectPolicy::new_linear(0, 10_000, 1000)) + .with_performance_config(|config| { + config.default_command_timeout = Duration::from_secs(6); + }) + .build_pool(pool_size) +} From 4b5f0d10871c1efc6d7fb9b541bd67698882dbc7 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 22 Aug 2025 20:26:00 -0400 Subject: [PATCH 09/46] server: tweak initial string capacities in stream writer --- server/src/stream/llm_writer.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 3dd6106..4c33836 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -21,8 +21,8 @@ use crate::{ /// Interval at which chunks are flushed to Redis. const FLUSH_INTERVAL: Duration = Duration::from_millis(500); -/// Max accumulated size of the text before it is automatically flushed to Redis. -const MAX_CHUNK_SIZE: usize = 1000; +/// Max accumulated size of the text chunk before it is automatically flushed to Redis. +const MAX_CHUNK_SIZE: usize = 400; /// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) const STREAM_EXPIRE: i64 = 30; @@ -164,7 +164,7 @@ impl LlmStreamWriter { .get_or_insert_with(|| String::with_capacity(MAX_CHUNK_SIZE + 200)) .push_str(text); self.complete_text - .get_or_insert_with(|| String::with_capacity(2000)) + .get_or_insert_with(|| String::with_capacity(500)) .push_str(text); } From d940fec60a3376e911a888ad1ed9f0cc74ada963 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 22 Aug 2025 21:00:47 -0400 Subject: [PATCH 10/46] server: fix stream cancellation --- server/src/redis.rs | 2 +- server/src/stream/llm_writer.rs | 18 ++++++++---------- server/src/stream/reader.rs | 4 +++- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/server/src/redis.rs b/server/src/redis.rs index 77ed90d..0acd616 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -51,7 +51,7 @@ pub fn build_redis_pool( }) .set_policy(ReconnectPolicy::new_linear(0, 10_000, 1000)) .with_performance_config(|config| { - config.default_command_timeout = Duration::from_secs(6); + config.default_command_timeout = Duration::from_secs(3); }) .build_pool(pool_size) } diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 4c33836..67a9d7a 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -101,6 +101,13 @@ impl LlmStreamWriter { pipeline.all().await } + /// Cancel stream by adding a `cancel` event to the stream and then deleting it from Redis. + pub async fn cancel(&self) -> Result<(), fred::prelude::Error> { + let entry: HashMap = RedisStreamChunk::Cancel.into(); + let _: () = self.redis.xadd(&self.key, true, None, "*", entry).await?; + self.redis.del(&self.key).await + } + /// Process the incoming stream from the LLM provider, intermittently flushing /// chunks to a Redis stream, and saving the final accumulated response to the database. pub async fn process( @@ -224,7 +231,7 @@ impl LlmStreamWriter { self.add_to_redis_stream(entries).await } - /// Adds a new entry to the Redis stream. Returns an error if the stream has been deleted or cancelled. + /// Adds a new entry to the Redis stream. Returns a `LlmError::StreamNotFound` error if the stream has been deleted or cancelled. async fn add_to_redis_stream( &self, entries: Vec>, @@ -254,15 +261,6 @@ impl LlmStreamWriter { pipeline.all().await } - /// Cancel stream by adding a `cancel` event to the stream and then deleting it from Redis. - pub async fn cancel(&self) -> Result<(), fred::prelude::Error> { - let entry: HashMap = RedisStreamChunk::Cancel.into(); - let pipeline = self.redis.next().pipeline(); - let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; - let _: () = pipeline.del(&self.key).await?; - pipeline.all().await - } - /// Saves the accumulated response to the database async fn save_to_db( &mut self, diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index 35aeeaf..6e54eac 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -117,7 +117,9 @@ impl SseStreamReader { match events.pop() { Some((id, data)) => { *last_event_id = id.clone(); - let is_end = data.get("type").is_some_and(|t| t == "end"); + let is_end = data + .get("type") + .is_some_and(|t| t == "end" || t == "cancel"); Ok((id, data, is_end)) } None => Err(LlmError::NoStreamEvent), From a5e900d739f8a8be8477b86ce1c86e91c6ee1be6 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 22 Aug 2025 21:22:11 -0400 Subject: [PATCH 11/46] server: fix get ongoing chat streams --- server/src/api/chat.rs | 6 +++--- server/src/stream/reader.rs | 12 ++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index b012a7b..c7a0e1f 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -43,7 +43,7 @@ pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { #[derive(Debug, JsonSchema, serde::Serialize)] pub struct GetChatStreamsResponse { - streams: Vec, + sessions: Vec, } /// # Get chat streams @@ -55,8 +55,8 @@ pub async fn get_chat_streams( redis: &State, ) -> Result, ApiError> { let stream_reader = SseStreamReader::new(&redis); - let keys = stream_reader.get_chat_streams(&user_id).await?; - Ok(Json(GetChatStreamsResponse { streams: keys })) + let sessions = stream_reader.get_chat_streams(&user_id).await?; + Ok(Json(GetChatStreamsResponse { sessions })) } #[derive(JsonSchema, serde::Deserialize)] diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index 6e54eac..5bed16e 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -28,14 +28,18 @@ impl SseStreamReader { } } - /// Get the ongoing chat streams for a user. + /// Get the ongoing chat stream sessions for a user. pub async fn get_chat_streams(&self, user_id: &Uuid) -> Result, LlmError> { - let pattern = format!("{}:*", get_chat_stream_prefix(user_id)); - let keys = self + let prefix = get_chat_stream_prefix(user_id); + let pattern = format!("{}:*", prefix); + let (_, keys): (String, Vec) = self .redis .scan_page("0", &pattern, Some(20), Some(ScanType::Stream)) .await?; - Ok(keys) + Ok(keys + .into_iter() + .filter_map(|key| Some(key.strip_prefix(&format!("{}:", prefix))?.to_string())) + .collect()) } /// Retrieve the previous events from the given Redis stream. From 29008f2273973d70019810725718896f5a2f347d Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 00:40:01 -0400 Subject: [PATCH 12/46] server: redis tweaks, add ping during SSE stream --- server/src/redis.rs | 2 +- server/src/stream/llm_writer.rs | 22 +++++++++++++--------- server/src/stream/reader.rs | 2 +- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/server/src/redis.rs b/server/src/redis.rs index 0acd616..c75a61f 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -51,7 +51,7 @@ pub fn build_redis_pool( }) .set_policy(ReconnectPolicy::new_linear(0, 10_000, 1000)) .with_performance_config(|config| { - config.default_command_timeout = Duration::from_secs(3); + config.default_command_timeout = Duration::from_secs(10); }) .build_pool(pool_size) } diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 67a9d7a..1a887cc 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -19,10 +19,10 @@ use crate::{ stream::get_chat_stream_key, }; -/// Interval at which chunks are flushed to Redis. +/// Interval at which chunks are flushed to the Redis stream. const FLUSH_INTERVAL: Duration = Duration::from_millis(500); /// Max accumulated size of the text chunk before it is automatically flushed to Redis. -const MAX_CHUNK_SIZE: usize = 400; +const MAX_CHUNK_SIZE: usize = 500; /// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) const STREAM_EXPIRE: i64 = 30; @@ -59,6 +59,7 @@ struct ChunkState { #[serde(tag = "type", content = "data", rename_all = "snake_case")] enum RedisStreamChunk { Start, + Ping, Text(String), ToolCall(String), Error(String), @@ -89,7 +90,9 @@ impl LlmStreamWriter { /// Check if the Redis stream already exists. pub async fn exists(&self) -> Result { - self.redis.exists(&self.key).await + let first_entry: Option = + self.redis.xread(Some(1), None, &self.key, "0-0").await?; + Ok(first_entry.is_some()) } /// Create the Redis stream and write a `start` entry. @@ -168,10 +171,10 @@ impl LlmStreamWriter { fn process_text(&mut self, text: &str) { self.current_chunk .text - .get_or_insert_with(|| String::with_capacity(MAX_CHUNK_SIZE + 200)) + .get_or_insert_with(|| String::with_capacity(MAX_CHUNK_SIZE)) .push_str(text); self.complete_text - .get_or_insert_with(|| String::with_capacity(500)) + .get_or_insert_with(|| String::with_capacity(MAX_CHUNK_SIZE)) .push_str(text); } @@ -205,10 +208,8 @@ impl LlmStreamWriter { if self.current_chunk.tool_calls.is_some() || self.current_chunk.error.is_some() { return true; } - if let Some(ref text) = self.current_chunk.text { - return text.len() > MAX_CHUNK_SIZE || last_flush_time.elapsed() > FLUSH_INTERVAL; - } - return false; + let text = self.current_chunk.text.as_ref(); + text.is_some_and(|t| t.len() > MAX_CHUNK_SIZE) || last_flush_time.elapsed() > FLUSH_INTERVAL } async fn flush_chunk(&mut self) -> Result<(), LlmError> { @@ -226,6 +227,9 @@ impl LlmStreamWriter { if let Some(error) = chunk_state.error { chunks.push(RedisStreamChunk::Error(error)); } + if chunks.is_empty() { + chunks.push(RedisStreamChunk::Ping); + } let entries = chunks.into_iter().map(|chunk| chunk.into()).collect(); self.add_to_redis_stream(entries).await diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index 5bed16e..7185b28 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -14,7 +14,7 @@ use crate::{ }; /// Timeout for the blocking `xread` command. -const XREAD_BLOCK_TIMEOUT: u64 = 10_000; // 10 seconds +const XREAD_BLOCK_TIMEOUT: u64 = 5_000; // 5 seconds /// Utility for reading SSE events from a Redis stream. pub struct SseStreamReader { From 8105dc5ea2b8bd50dd5d815c1be0db5a23af8c7f Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 01:21:24 -0400 Subject: [PATCH 13/46] server: add timeout if LLM stream stops sending data --- server/src/provider.rs | 2 ++ server/src/stream/llm_writer.rs | 18 +++++++++++++----- server/src/stream/reader.rs | 2 +- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/server/src/provider.rs b/server/src/provider.rs index 09ff351..1ffb5c4 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -45,6 +45,8 @@ pub enum LlmError { NoStreamEvent, #[error("Client disconnected")] ClientDisconnected, + #[error("Timeout waiting for provider")] + StreamTimeout, #[error("Encryption error")] EncryptionError, #[error("Decryption error")] diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 1a887cc..b156819 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -25,6 +25,8 @@ const FLUSH_INTERVAL: Duration = Duration::from_millis(500); const MAX_CHUNK_SIZE: usize = 500; /// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) const STREAM_EXPIRE: i64 = 30; +/// Timeout waiting for data from the LLM stream. +const LLM_TIMEOUT: Duration = Duration::from_secs(20); /// Utility for processing an incoming LLM response stream and writing to a Redis stream. #[derive(Debug)] @@ -122,9 +124,10 @@ impl LlmStreamWriter { ) { let mut last_flush_time = Instant::now(); let mut cancelled = false; - while let Some(chunk) = stream.next().await { - match chunk { - Ok(chunk) => { + + loop { + match tokio::time::timeout(LLM_TIMEOUT, stream.next()).await { + Ok(Some(Ok(chunk))) => { if let Some(ref text) = chunk.text { self.process_text(text); } @@ -135,13 +138,18 @@ impl LlmStreamWriter { self.process_usage(usage_chunk); } } - Err(err) => { - self.process_error(err); + Ok(Some(Err(err))) => self.process_error(err), + Ok(None) => break, + Err(_) => { + self.process_error(LlmError::StreamTimeout); + cancelled = true; + break; } } if self.should_flush(&last_flush_time) { if let Err(err) = self.flush_chunk().await { + // Check if stream has been cancelled if matches!(err, LlmError::StreamNotFound) { self.errors.get_or_insert_default().push(err); cancelled = true; diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index 7185b28..8b256cb 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -102,7 +102,7 @@ impl SseStreamReader { /// Get the next event from the given Redis stream using a blocking `xread` command. /// - Updates the last event ID /// - Cancels waiting for the next event if the client disconnects - /// - Returns the event ID, data, and a `bool` indicating whether it's the ending event + /// - Returns the event ID, data, and a `bool` indicating whether it's an ending event async fn get_next_event( &self, key: &str, From 9bd16cea8bda1160dd252493b5757a5b3e0ce5f8 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 01:52:29 -0400 Subject: [PATCH 14/46] server: add ping to Redis stream --- server/src/provider.rs | 2 +- server/src/stream/llm_writer.rs | 31 ++++++++++++++++++++++++++----- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/server/src/provider.rs b/server/src/provider.rs index 1ffb5c4..7df7869 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -45,7 +45,7 @@ pub enum LlmError { NoStreamEvent, #[error("Client disconnected")] ClientDisconnected, - #[error("Timeout waiting for provider")] + #[error("Timeout waiting for provider response")] StreamTimeout, #[error("Encryption error")] EncryptionError, diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index b156819..ff82f78 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -27,6 +27,8 @@ const MAX_CHUNK_SIZE: usize = 500; const STREAM_EXPIRE: i64 = 30; /// Timeout waiting for data from the LLM stream. const LLM_TIMEOUT: Duration = Duration::from_secs(20); +/// Interval for sending ping messages to keep the Redis stream alive. +const PING_INTERVAL: Duration = Duration::from_millis(1000); /// Utility for processing an incoming LLM response stream and writing to a Redis stream. #[derive(Debug)] @@ -122,9 +124,10 @@ impl LlmStreamWriter { provider_id: i32, provider_options: LlmApiProviderSharedOptions, ) { + let ping_handle = self.start_ping_task(); + let mut last_flush_time = Instant::now(); let mut cancelled = false; - loop { match tokio::time::timeout(LLM_TIMEOUT, stream.next()).await { Ok(Some(Ok(chunk))) => { @@ -142,14 +145,13 @@ impl LlmStreamWriter { Ok(None) => break, Err(_) => { self.process_error(LlmError::StreamTimeout); - cancelled = true; break; } } if self.should_flush(&last_flush_time) { if let Err(err) = self.flush_chunk().await { - // Check if stream has been cancelled + // Check if stream has been cancelled or deleted if matches!(err, LlmError::StreamNotFound) { self.errors.get_or_insert_default().push(err); cancelled = true; @@ -161,6 +163,7 @@ impl LlmStreamWriter { } } + ping_handle.abort(); if let Err(e) = self .save_to_db(db, provider_id, provider_options, cancelled) .await @@ -170,7 +173,6 @@ impl LlmStreamWriter { self.flush_chunk().await.ok(); } } - if !cancelled { self.finish().await.ok(); } @@ -236,7 +238,7 @@ impl LlmStreamWriter { chunks.push(RedisStreamChunk::Error(error)); } if chunks.is_empty() { - chunks.push(RedisStreamChunk::Ping); + return Ok(()); } let entries = chunks.into_iter().map(|chunk| chunk.into()).collect(); @@ -309,4 +311,23 @@ impl LlmStreamWriter { Ok(()) } + + /// Start task that pings the Redis stream every PING_INTERVAL seconds + fn start_ping_task(&self) -> tokio::task::JoinHandle<()> { + let redis = self.redis.clone(); + let key = self.key.to_owned(); + let ping_handle = tokio::spawn(async move { + let mut interval = tokio::time::interval(PING_INTERVAL); + loop { + interval.tick().await; + let entry: HashMap = RedisStreamChunk::Ping.into(); + let res: Result<(), fred::error::Error> = + redis.xadd(&key, true, None, "*", entry).await; + if res.is_err() { + break; + } + } + }); + ping_handle + } } From 800b906d4351ea709c6fbcec3238ba0233fcf948 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 02:02:17 -0400 Subject: [PATCH 15/46] server: final stream tweaks --- server/src/stream/llm_writer.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index ff82f78..cc3c779 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -22,7 +22,7 @@ use crate::{ /// Interval at which chunks are flushed to the Redis stream. const FLUSH_INTERVAL: Duration = Duration::from_millis(500); /// Max accumulated size of the text chunk before it is automatically flushed to Redis. -const MAX_CHUNK_SIZE: usize = 500; +const MAX_CHUNK_SIZE: usize = 200; /// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) const STREAM_EXPIRE: i64 = 30; /// Timeout waiting for data from the LLM stream. @@ -184,7 +184,7 @@ impl LlmStreamWriter { .get_or_insert_with(|| String::with_capacity(MAX_CHUNK_SIZE)) .push_str(text); self.complete_text - .get_or_insert_with(|| String::with_capacity(MAX_CHUNK_SIZE)) + .get_or_insert_with(|| String::with_capacity(1024)) .push_str(text); } @@ -252,7 +252,9 @@ impl LlmStreamWriter { ) -> Result<(), LlmError> { let pipeline = self.redis.next().pipeline(); for entry in entries { - let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; + let _: () = pipeline + .xadd(&self.key, true, ("MAXLEN", "~", 500), "*", entry) + .await?; } let _: () = pipeline.expire(&self.key, STREAM_EXPIRE, None).await?; let res: Vec = pipeline.all().await?; From 16195b6b04d98ab2649a152f82b557e116192352 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 02:17:12 -0400 Subject: [PATCH 16/46] server: moar tweaks --- server/src/stream/llm_writer.rs | 13 ++++++------- server/src/stream/reader.rs | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index cc3c779..2b21fb9 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -27,8 +27,8 @@ const MAX_CHUNK_SIZE: usize = 200; const STREAM_EXPIRE: i64 = 30; /// Timeout waiting for data from the LLM stream. const LLM_TIMEOUT: Duration = Duration::from_secs(20); -/// Interval for sending ping messages to keep the Redis stream alive. -const PING_INTERVAL: Duration = Duration::from_millis(1000); +/// Interval for sending ping messages to the Redis stream. +const PING_INTERVAL: Duration = Duration::from_secs(2); /// Utility for processing an incoming LLM response stream and writing to a Redis stream. #[derive(Debug)] @@ -94,8 +94,7 @@ impl LlmStreamWriter { /// Check if the Redis stream already exists. pub async fn exists(&self) -> Result { - let first_entry: Option = - self.redis.xread(Some(1), None, &self.key, "0-0").await?; + let first_entry: Option<()> = self.redis.xread(Some(1), None, &self.key, "0-0").await?; Ok(first_entry.is_some()) } @@ -108,7 +107,8 @@ impl LlmStreamWriter { pipeline.all().await } - /// Cancel stream by adding a `cancel` event to the stream and then deleting it from Redis. + /// Cancel the current stream by adding a `cancel` event to the stream and then deleting it from Redis + /// (not using a pipeline since we need to ensure the `cancel` event is processed before deleting the stream). pub async fn cancel(&self) -> Result<(), fred::prelude::Error> { let entry: HashMap = RedisStreamChunk::Cancel.into(); let _: () = self.redis.xadd(&self.key, true, None, "*", entry).await?; @@ -151,7 +151,6 @@ impl LlmStreamWriter { if self.should_flush(&last_flush_time) { if let Err(err) = self.flush_chunk().await { - // Check if stream has been cancelled or deleted if matches!(err, LlmError::StreamNotFound) { self.errors.get_or_insert_default().push(err); cancelled = true; @@ -314,7 +313,7 @@ impl LlmStreamWriter { Ok(()) } - /// Start task that pings the Redis stream every PING_INTERVAL seconds + /// Start task that pings the Redis stream every `PING_INTERVAL` seconds fn start_ping_task(&self) -> tokio::task::JoinHandle<()> { let redis = self.redis.clone(); let key = self.key.to_owned(); diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index 8b256cb..809d7bb 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -13,7 +13,7 @@ use crate::{ stream::{get_chat_stream_key, get_chat_stream_prefix}, }; -/// Timeout for the blocking `xread` command. +/// Timeout in milliseconds for the blocking `xread` command. const XREAD_BLOCK_TIMEOUT: u64 = 5_000; // 5 seconds /// Utility for reading SSE events from a Redis stream. From 2337b4903acbe16009065b1ad4cf62bdda1fb9d2 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 02:45:07 -0400 Subject: [PATCH 17/46] add architecture document --- ARCHITECTURE.md | 106 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 ARCHITECTURE.md diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..500f6e2 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,106 @@ +# RsChat Architecture + +## Overview + +RsChat is a real-time chat application that provides resumable streaming conversations with LLM providers. The architecture is designed for high performance, scalability across multiple server instances, and resilient streaming that can survive network interruptions. + +## Core Architecture + +### Frontend (React/TypeScript) +- **Location**: `web/` +- **Streaming**: Server-Sent Events (SSE) +- **State Management**: React Context for streaming state (`web/src/lib/context/StreamingContext.tsx`) +- **API Integration**: Type-safe API calls (generated from OpenAPI: `web/src/lib/api/types.d.ts`) + +### Backend (Rust/Rocket) +- **Location**: `server/` +- **Framework**: Rocket with async/await support +- **Database**: PostgreSQL for persistent storage +- **Cache/Streaming**: Redis for stream management and caching + +## LLM Streaming Architecture + +### Dual-Stream Approach + +RsChat uses a hybrid streaming architecture that provides both real-time performance and cross-instance resumability: + +1. **Server**: Redis Streams for resumability and multi-instance support +2. **Client**: Server-Sent Events (SSE) read from the Redis streams + +### Key Components + +#### 1. LlmStreamWriter (`server/src/stream/llm_writer.rs`) + +The core component that processes LLM provider streams and manages Redis stream output. + +**Key Features:** +- **Batching**: Accumulates chunks from the provider stream, up to a max length or timeout +- **Background Pings**: Sends regular keepalive pings +- **Timeout Detection**: 20-second timeout for idle LLM streams +- **Database Integration**: Saves final responses to PostgreSQL + +#### 2. Redis and SSE Stream Structure + +**Redis Key for Chat Streams**: `user:{user_id}:chat:{session_id}` + +**Chat Stream Message Types**: +- `start`: Stream initialization +- `text`: Accumulated text chunks +- `tool_call`: LLM tool invocations (JSON stringified) +- `error`: Error messages +- `ping`: Keepalive messages +- `end`: Stream completion +- `cancel`: Stream cancellation + +#### 3. Stream Lifecycle + +``` +Client Request → SSE Connection → LlmStreamWriter.create() + ↓ +LLM Provider Stream → Batching Data Chunks → Redis XADD + ↓ +Background Ping Task (intervals) + ↓ +Stream End → Database Save → Redis DEL +``` + + +### Resumability Features + +#### Cross-Instance Support +- Redis streams provide shared state across server instances +- Background ping tasks maintain stream liveness +- Stream cancellation detected via Redis XADD failures + +## Data Flow + +### 1. New Chat Request +``` +Client → POST /api/chat/{session_id} + → Send request to LLM Provider + → SSE Response Stream created + → LlmStreamWriter.create() + → Redis Stream created +``` + +### 2. Stream Processing +``` +LLM Chunk → Process text, tool calls, usage, and error chunks + → Batching Logic + → Redis XADD (if conditions met) + → Continue SSE Stream +``` + +### 3. Stream Completion +``` +LLM End → Final Database Save + → Redis Stream End Event + → Redis Stream Cleanup + → SSE Connection Close +``` + +### 4. Reconnection/Resume +``` +Client Reconnect → Check ongoing streams via GET /api/chat/streams + → Reconnect to stream (if active) +``` From e25d43004092c821ebbfb9fee06abb17efba02e6 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 03:03:20 -0400 Subject: [PATCH 18/46] web: refactor/organize React context for streams --- web/src/components/Sidebar.tsx | 2 +- .../components/chat/ChatStreamingMessages.tsx | 2 +- .../chat/ChatStreamingToolCalls.tsx | 2 +- .../chat/messages/ChatMessageToolCalls.tsx | 2 +- web/src/lib/context/StreamingContext.tsx | 421 +----------------- web/src/lib/context/chats.ts | 45 ++ web/src/lib/context/index.ts | 10 + web/src/lib/context/streamManager.ts | 351 +++++++++++++++ web/src/lib/context/tools.ts | 31 ++ .../app/_appLayout/session/$sessionId.tsx | 5 +- 10 files changed, 444 insertions(+), 427 deletions(-) create mode 100644 web/src/lib/context/chats.ts create mode 100644 web/src/lib/context/index.ts create mode 100644 web/src/lib/context/streamManager.ts create mode 100644 web/src/lib/context/tools.ts diff --git a/web/src/components/Sidebar.tsx b/web/src/components/Sidebar.tsx index 1afa0c3..77ebb43 100644 --- a/web/src/components/Sidebar.tsx +++ b/web/src/components/Sidebar.tsx @@ -39,7 +39,7 @@ import { } from "@/components/ui/sidebar"; import { useCreateChatSession } from "@/lib/api/session"; import type { components } from "@/lib/api/types"; -import type { StreamedChat } from "@/lib/context/StreamingContext"; +import type { StreamedChat } from "@/lib/context"; import { Avatar, AvatarFallback, AvatarImage } from "./ui/avatar"; import { Button } from "./ui/button"; import { diff --git a/web/src/components/chat/ChatStreamingMessages.tsx b/web/src/components/chat/ChatStreamingMessages.tsx index ab0890e..d536999 100644 --- a/web/src/components/chat/ChatStreamingMessages.tsx +++ b/web/src/components/chat/ChatStreamingMessages.tsx @@ -3,7 +3,7 @@ import { useEffect } from "react"; import Markdown from "react-markdown"; import useSmoothStreaming from "@/hooks/useSmoothStreaming"; -import type { StreamedChat } from "@/lib/context/StreamingContext"; +import type { StreamedChat } from "@/lib/context"; import { cn } from "@/lib/utils"; import { ChatBubble, diff --git a/web/src/components/chat/ChatStreamingToolCalls.tsx b/web/src/components/chat/ChatStreamingToolCalls.tsx index 72604f0..82645ff 100644 --- a/web/src/components/chat/ChatStreamingToolCalls.tsx +++ b/web/src/components/chat/ChatStreamingToolCalls.tsx @@ -16,7 +16,7 @@ import { } from "@/components/ui/collapsible"; import useSmoothStreaming from "@/hooks/useSmoothStreaming"; import type { components } from "@/lib/api/types"; -import type { StreamedToolExecution } from "@/lib/context/StreamingContext"; +import type { StreamedToolExecution } from "@/lib/context"; import { getToolFromToolCall } from "@/lib/tools"; import { cn, escapeBackticks } from "@/lib/utils"; import { useAutoScroll } from "../ui/chat/hooks/useAutoScroll"; diff --git a/web/src/components/chat/messages/ChatMessageToolCalls.tsx b/web/src/components/chat/messages/ChatMessageToolCalls.tsx index 9cd52f8..82a7695 100644 --- a/web/src/components/chat/messages/ChatMessageToolCalls.tsx +++ b/web/src/components/chat/messages/ChatMessageToolCalls.tsx @@ -4,7 +4,7 @@ import { lazy, Suspense, useMemo, useState } from "react"; import { getToolIcon, getToolTypeLabel } from "@/components/ToolsManager"; import { Button } from "@/components/ui/button"; import type { components } from "@/lib/api/types"; -import { useStreamingTools } from "@/lib/context/StreamingContext"; +import { useStreamingTools } from "@/lib/context"; import { getToolFromToolCall } from "@/lib/tools"; import { cn } from "@/lib/utils"; diff --git a/web/src/lib/context/StreamingContext.tsx b/web/src/lib/context/StreamingContext.tsx index da605cc..1e2dac3 100644 --- a/web/src/lib/context/StreamingContext.tsx +++ b/web/src/lib/context/StreamingContext.tsx @@ -1,428 +1,11 @@ -import { useQueryClient } from "@tanstack/react-query"; -import { createContext, useCallback, useContext, useState } from "react"; - -import { streamChat } from "../api/chat"; -import { chatSessionQueryKey, recentSessionsQueryKey } from "../api/session"; -import { streamToolExecution } from "../api/tool"; -import type { components } from "../api/types"; - -export interface StreamedChat { - content: string; - error?: string; - status: "streaming" | "completed"; -} - -export interface StreamedToolExecution { - result: string; - logs: string[]; - debugLogs: string[]; - error?: string; - status: "streaming" | "completed" | "error"; -} -const streamedToolExecutionInit = (): StreamedToolExecution => ({ - result: "", - logs: [], - debugLogs: [], - status: "streaming", -}); - -/** Hook to stream a chat, get chat stream status, etc. */ -export const useStreamingChats = () => { - const queryClient = useQueryClient(); - const { streamedChats, startStream } = useContext(ChatStreamContext); - - /** Start stream + optimistic update of user message */ - const onUserSubmit = useCallback( - (sessionId: string, input: components["schemas"]["SendChatInput"]) => { - startStream(sessionId, input); - if (!input.message) return; - - queryClient.setQueryData<{ - messages: components["schemas"]["ChatRsMessage"][]; - }>(["chatSession", { sessionId }], (oldData: any) => { - if (!oldData) return {}; - return { - ...oldData, - messages: [ - ...oldData.messages, - { - id: crypto.randomUUID(), - content: input.message, - role: "User", - created_at: new Date().toISOString(), - session_id: sessionId, - meta: {}, - }, - ], - }; - }); - }, - [startStream, queryClient], - ); - - return { - onUserSubmit, - streamedChats, - }; -}; - -/** Hook to stream tool executions */ -export const useStreamingTools = () => { - const { streamedTools, startToolExecution, cancelToolExecution } = - useContext(ChatStreamContext); - - /** Execute a tool with streaming */ - const onToolExecute = useCallback( - (messageId: string, sessionId: string, toolCallId: string) => { - startToolExecution(messageId, sessionId, toolCallId); - }, - [startToolExecution], - ); - - /** Cancel a tool execution */ - const onToolCancel = useCallback( - (sessionId: string, toolCallId: string) => { - cancelToolExecution(sessionId, toolCallId); - }, - [cancelToolExecution], - ); - - return { - streamedTools, - onToolExecute, - onToolCancel, - }; -}; - -/** Manage ongoing chat streams and tool executions */ -const useChatStreamManager = () => { - const [streamedChats, setStreamedChats] = useState<{ - [sessionId: string]: StreamedChat | undefined; - }>({}); - - const [streamedTools, setStreamedTools] = useState<{ - [toolCallId: string]: StreamedToolExecution | undefined; - }>({}); - - const [activeToolStreams, setActiveToolStreams] = useState<{ - [toolCallId: string]: { close: () => void } | undefined; - }>({}); - - const addChatPart = useCallback((sessionId: string, part: string) => { - setStreamedChats((prev) => ({ - ...prev, - [sessionId]: { - content: (prev?.[sessionId]?.content || "") + part, - error: prev?.[sessionId]?.error, - status: "streaming", - }, - })); - }, []); - - const addChatError = useCallback((sessionId: string, error: string) => { - setStreamedChats((prev) => ({ - ...prev, - [sessionId]: { - content: prev?.[sessionId]?.content || "", - status: "streaming", - error, - }, - })); - }, []); - - const setChatStatus = useCallback( - (sessionId: string, status: "streaming" | "completed") => { - setStreamedChats((prev) => ({ - ...prev, - [sessionId]: { - status, - content: prev?.[sessionId]?.content || "", - error: prev?.[sessionId]?.error, - }, - })); - }, - [], - ); - - const clearChat = useCallback((sessionId: string) => { - setStreamedChats((prev) => ({ - ...prev, - [sessionId]: undefined, - })); - }, []); - - const queryClient = useQueryClient(); - - const invalidateSession = useCallback( - async (sessionId: string) => { - await Promise.allSettled([ - queryClient.invalidateQueries({ - queryKey: chatSessionQueryKey(sessionId), - }), - queryClient.invalidateQueries({ - queryKey: recentSessionsQueryKey, - }), - ]); - }, - [queryClient], - ); - - /** Refetch chat session for the new assistant message */ - const refetchSessionForNewAssistantResponse = useCallback( - async (sessionId: string) => { - const retryDelay = 1000; // 1 second - try { - // Refetch chat session with retry loop - let hasNewAssistantMessage = false; - let retryCount = 0; - const maxRetries = 3; - - while (!hasNewAssistantMessage && retryCount < maxRetries) { - await invalidateSession(sessionId); - - // Check if the chat session has been updated with the new assistant response - const updatedData = queryClient.getQueryData<{ - messages: components["schemas"]["ChatRsMessage"][]; - }>(["chatSession", { sessionId }]); - hasNewAssistantMessage = - updatedData?.messages?.some( - (msg) => - msg.role === "Assistant" && - !msg.meta.assistant?.partial && - new Date(msg.created_at).getTime() > Date.now() - 5000, // Within last 5 seconds - ) || false; - - // Retry if no new assistant message - if (!hasNewAssistantMessage) { - retryCount++; - if (retryCount < maxRetries) { - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - } - } - } - } catch (error) { - console.error("Error refetching chat session:", error); - await invalidateSession(sessionId); - } - }, - [invalidateSession, queryClient], - ); - - /** Refetch chat session for the new tool message */ - const refetchSessionForNewToolMessage = useCallback( - async (sessionId: string, toolCallId: string) => { - const retryDelay = 1000; // 1 second - try { - let hasNewToolMessage = false; - let retryCount = 0; - const maxRetries = 3; - - while (!hasNewToolMessage && retryCount < maxRetries) { - await invalidateSession(sessionId); - - const updatedData = queryClient.getQueryData<{ - messages: components["schemas"]["ChatRsMessage"][]; - }>(["chatSession", { sessionId }]); - hasNewToolMessage = - updatedData?.messages?.some( - (msg) => - msg.role === "Tool" && msg.meta.tool_call?.id === toolCallId, - ) || false; - - if (!hasNewToolMessage) { - retryCount++; - if (retryCount < maxRetries) { - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - } - } - } - } catch (error) { - console.error("Error refetching chat session:", error); - await invalidateSession(sessionId); - } - }, - [invalidateSession, queryClient], - ); - - /** Start a new chat stream */ - const startStream = useCallback( - (sessionId: string, input: components["schemas"]["SendChatInput"]) => { - clearChat(sessionId); - setChatStatus(sessionId, "streaming"); - const stream = streamChat(sessionId, input, { - onPart: (part) => { - addChatPart(sessionId, part); - }, - onError: (error) => { - addChatError(sessionId, error); - }, - }); - stream.start - .then(() => { - refetchSessionForNewAssistantResponse(sessionId).then(() => - clearChat(sessionId), - ); - }) - .catch(() => { - invalidateSession(sessionId).then(() => - setChatStatus(sessionId, "completed"), - ); - }); - }, - [ - clearChat, - addChatPart, - addChatError, - setChatStatus, - invalidateSession, - refetchSessionForNewAssistantResponse, - ], - ); - - /** Add tool execution result chunk */ - const addToolResult = useCallback((toolCallId: string, result: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - result: (prev?.[toolCallId]?.result || "") + result, - }, - })); - }, []); - - /** Add tool execution log */ - const addToolLog = useCallback((toolCallId: string, log: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - logs: [...(prev?.[toolCallId]?.logs || []), log], - }, - })); - }, []); - - /** Add tool execution debug log */ - const addToolDebug = useCallback((toolCallId: string, debug: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - debugLogs: [...(prev?.[toolCallId]?.debugLogs || []), debug], - }, - })); - }, []); - - /** Add tool execution error */ - const addToolError = useCallback((toolCallId: string, error: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - error, - status: "error", - }, - })); - }, []); - - /** Set tool execution status */ - const setToolStatus = useCallback( - (toolCallId: string, status: "streaming" | "completed" | "error") => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - status, - }, - })); - }, - [], - ); - - /** Clear active tool execution */ - const clearTool = useCallback((toolCallId: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: undefined, - })); - setActiveToolStreams((prev) => ({ - ...prev, - [toolCallId]: undefined, - })); - }, []); - - /** Start tool execution stream */ - const startToolExecution = useCallback( - (messageId: string, sessionId: string, toolCallId: string) => { - const stream = streamToolExecution(messageId, toolCallId, { - onResult: (data) => addToolResult(toolCallId, data), - onLog: (data) => addToolLog(toolCallId, data), - onDebug: (data) => addToolDebug(toolCallId, data), - onError: (error) => addToolError(toolCallId, error), - }); - - clearTool(toolCallId); - setToolStatus(toolCallId, "streaming"); - setActiveToolStreams((prev) => ({ - ...prev, - [toolCallId]: { close: stream.close }, - })); - - stream.start - .then(() => setToolStatus(toolCallId, "completed")) - .catch(() => setToolStatus(toolCallId, "error")) - .finally(() => { - refetchSessionForNewToolMessage(sessionId, toolCallId).then(() => { - clearTool(toolCallId); - }); - }); - }, - [ - setToolStatus, - addToolResult, - addToolLog, - addToolDebug, - addToolError, - clearTool, - refetchSessionForNewToolMessage, - ], - ); - - /** Cancel tool execution */ - const cancelToolExecution = useCallback( - (sessionId: string, toolCallId: string) => { - const activeStream = activeToolStreams[toolCallId]; - if (activeStream) { - activeStream.close(); - refetchSessionForNewToolMessage(sessionId, toolCallId).then(() => { - clearTool(toolCallId); - }); - } - }, - [activeToolStreams, clearTool, refetchSessionForNewToolMessage], - ); - - return { - startStream, - streamedChats, - streamedTools, - startToolExecution, - cancelToolExecution, - }; -}; - -const ChatStreamContext = createContext< - ReturnType ->( - //@ts-expect-error should be initialized - null, -); +import { ChatStreamContext, useStreamManager } from "./streamManager"; export const ChatStreamProvider = ({ children, }: { children: React.ReactNode; }) => { - const chatStreamManager = useChatStreamManager(); + const chatStreamManager = useStreamManager(); return ( diff --git a/web/src/lib/context/chats.ts b/web/src/lib/context/chats.ts new file mode 100644 index 0000000..6879caa --- /dev/null +++ b/web/src/lib/context/chats.ts @@ -0,0 +1,45 @@ +import { useQueryClient } from "@tanstack/react-query"; +import { useCallback, useContext } from "react"; + +import type { components } from "../api/types"; +import { ChatStreamContext } from "./streamManager"; + +/** Hook to stream a chat, get chat stream status, etc. */ +export const useStreamingChats = () => { + const queryClient = useQueryClient(); + const { streamedChats, startStream } = useContext(ChatStreamContext); + + /** Start stream + optimistic update of user message */ + const onUserSubmit = useCallback( + (sessionId: string, input: components["schemas"]["SendChatInput"]) => { + startStream(sessionId, input); + if (!input.message) return; + + queryClient.setQueryData<{ + messages: components["schemas"]["ChatRsMessage"][]; + }>(["chatSession", { sessionId }], (oldData: any) => { + if (!oldData) return {}; + return { + ...oldData, + messages: [ + ...oldData.messages, + { + id: crypto.randomUUID(), + content: input.message, + role: "User", + created_at: new Date().toISOString(), + session_id: sessionId, + meta: {}, + }, + ], + }; + }); + }, + [startStream, queryClient], + ); + + return { + onUserSubmit, + streamedChats, + }; +}; diff --git a/web/src/lib/context/index.ts b/web/src/lib/context/index.ts new file mode 100644 index 0000000..79f59bd --- /dev/null +++ b/web/src/lib/context/index.ts @@ -0,0 +1,10 @@ +import { useStreamingChats } from "./chats"; +import type { StreamedChat, StreamedToolExecution } from "./streamManager"; +import { useStreamingTools } from "./tools"; + +export { + useStreamingTools, + useStreamingChats, + type StreamedChat, + type StreamedToolExecution, +}; diff --git a/web/src/lib/context/streamManager.ts b/web/src/lib/context/streamManager.ts new file mode 100644 index 0000000..95e5b1e --- /dev/null +++ b/web/src/lib/context/streamManager.ts @@ -0,0 +1,351 @@ +import { useQueryClient } from "@tanstack/react-query"; +import { createContext, useCallback, useState } from "react"; + +import { streamChat } from "../api/chat"; +import { chatSessionQueryKey, recentSessionsQueryKey } from "../api/session"; +import { streamToolExecution } from "../api/tool"; +import type { components } from "../api/types"; + +export interface StreamedChat { + content: string; + error?: string; + status: "streaming" | "completed"; +} + +export interface StreamedToolExecution { + result: string; + logs: string[]; + debugLogs: string[]; + error?: string; + status: "streaming" | "completed" | "error"; +} + +const streamedToolExecutionInit = (): StreamedToolExecution => ({ + result: "", + logs: [], + debugLogs: [], + status: "streaming", +}); + +/** Manage ongoing chat streams and tool executions */ +export const useStreamManager = () => { + const [streamedChats, setStreamedChats] = useState<{ + [sessionId: string]: StreamedChat | undefined; + }>({}); + + const [streamedTools, setStreamedTools] = useState<{ + [toolCallId: string]: StreamedToolExecution | undefined; + }>({}); + + const [activeToolStreams, setActiveToolStreams] = useState<{ + [toolCallId: string]: { close: () => void } | undefined; + }>({}); + + const addChatPart = useCallback((sessionId: string, part: string) => { + setStreamedChats((prev) => ({ + ...prev, + [sessionId]: { + content: (prev?.[sessionId]?.content || "") + part, + error: prev?.[sessionId]?.error, + status: "streaming", + }, + })); + }, []); + + const addChatError = useCallback((sessionId: string, error: string) => { + setStreamedChats((prev) => ({ + ...prev, + [sessionId]: { + content: prev?.[sessionId]?.content || "", + status: "streaming", + error, + }, + })); + }, []); + + const setChatStatus = useCallback( + (sessionId: string, status: "streaming" | "completed") => { + setStreamedChats((prev) => ({ + ...prev, + [sessionId]: { + status, + content: prev?.[sessionId]?.content || "", + error: prev?.[sessionId]?.error, + }, + })); + }, + [], + ); + + const clearChat = useCallback((sessionId: string) => { + setStreamedChats((prev) => ({ + ...prev, + [sessionId]: undefined, + })); + }, []); + + const queryClient = useQueryClient(); + + const invalidateSession = useCallback( + async (sessionId: string) => { + await Promise.allSettled([ + queryClient.invalidateQueries({ + queryKey: chatSessionQueryKey(sessionId), + }), + queryClient.invalidateQueries({ + queryKey: recentSessionsQueryKey, + }), + ]); + }, + [queryClient], + ); + + /** Refetch chat session for the new assistant message */ + const refetchSessionForNewAssistantResponse = useCallback( + async (sessionId: string) => { + const retryDelay = 1000; // 1 second + try { + // Refetch chat session with retry loop + let hasNewAssistantMessage = false; + let retryCount = 0; + const maxRetries = 3; + + while (!hasNewAssistantMessage && retryCount < maxRetries) { + await invalidateSession(sessionId); + + // Check if the chat session has been updated with the new assistant response + const updatedData = queryClient.getQueryData<{ + messages: components["schemas"]["ChatRsMessage"][]; + }>(["chatSession", { sessionId }]); + hasNewAssistantMessage = + updatedData?.messages?.some( + (msg) => + msg.role === "Assistant" && + !msg.meta.assistant?.partial && + new Date(msg.created_at).getTime() > Date.now() - 5000, // Within last 5 seconds + ) || false; + + // Retry if no new assistant message + if (!hasNewAssistantMessage) { + retryCount++; + if (retryCount < maxRetries) { + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + } + } + } + } catch (error) { + console.error("Error refetching chat session:", error); + await invalidateSession(sessionId); + } + }, + [invalidateSession, queryClient], + ); + + /** Refetch chat session for the new tool message */ + const refetchSessionForNewToolMessage = useCallback( + async (sessionId: string, toolCallId: string) => { + const retryDelay = 1000; // 1 second + try { + let hasNewToolMessage = false; + let retryCount = 0; + const maxRetries = 3; + + while (!hasNewToolMessage && retryCount < maxRetries) { + await invalidateSession(sessionId); + + const updatedData = queryClient.getQueryData<{ + messages: components["schemas"]["ChatRsMessage"][]; + }>(["chatSession", { sessionId }]); + hasNewToolMessage = + updatedData?.messages?.some( + (msg) => + msg.role === "Tool" && msg.meta.tool_call?.id === toolCallId, + ) || false; + + if (!hasNewToolMessage) { + retryCount++; + if (retryCount < maxRetries) { + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + } + } + } + } catch (error) { + console.error("Error refetching chat session:", error); + await invalidateSession(sessionId); + } + }, + [invalidateSession, queryClient], + ); + + /** Start a new chat stream */ + const startStream = useCallback( + (sessionId: string, input: components["schemas"]["SendChatInput"]) => { + clearChat(sessionId); + setChatStatus(sessionId, "streaming"); + const stream = streamChat(sessionId, input, { + onPart: (part) => { + addChatPart(sessionId, part); + }, + onError: (error) => { + addChatError(sessionId, error); + }, + }); + stream.start + .then(() => { + refetchSessionForNewAssistantResponse(sessionId).then(() => + clearChat(sessionId), + ); + }) + .catch(() => { + invalidateSession(sessionId).then(() => + setChatStatus(sessionId, "completed"), + ); + }); + }, + [ + clearChat, + addChatPart, + addChatError, + setChatStatus, + invalidateSession, + refetchSessionForNewAssistantResponse, + ], + ); + + /** Add tool execution result chunk */ + const addToolResult = useCallback((toolCallId: string, result: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || streamedToolExecutionInit()), + result: (prev?.[toolCallId]?.result || "") + result, + }, + })); + }, []); + + /** Add tool execution log */ + const addToolLog = useCallback((toolCallId: string, log: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || streamedToolExecutionInit()), + logs: [...(prev?.[toolCallId]?.logs || []), log], + }, + })); + }, []); + + /** Add tool execution debug log */ + const addToolDebug = useCallback((toolCallId: string, debug: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || streamedToolExecutionInit()), + debugLogs: [...(prev?.[toolCallId]?.debugLogs || []), debug], + }, + })); + }, []); + + /** Add tool execution error */ + const addToolError = useCallback((toolCallId: string, error: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || streamedToolExecutionInit()), + error, + status: "error", + }, + })); + }, []); + + /** Set tool execution status */ + const setToolStatus = useCallback( + (toolCallId: string, status: "streaming" | "completed" | "error") => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || streamedToolExecutionInit()), + status, + }, + })); + }, + [], + ); + + /** Clear active tool execution */ + const clearTool = useCallback((toolCallId: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: undefined, + })); + setActiveToolStreams((prev) => ({ + ...prev, + [toolCallId]: undefined, + })); + }, []); + + /** Start tool execution stream */ + const startToolExecution = useCallback( + (messageId: string, sessionId: string, toolCallId: string) => { + const stream = streamToolExecution(messageId, toolCallId, { + onResult: (data) => addToolResult(toolCallId, data), + onLog: (data) => addToolLog(toolCallId, data), + onDebug: (data) => addToolDebug(toolCallId, data), + onError: (error) => addToolError(toolCallId, error), + }); + + clearTool(toolCallId); + setToolStatus(toolCallId, "streaming"); + setActiveToolStreams((prev) => ({ + ...prev, + [toolCallId]: { close: stream.close }, + })); + + stream.start + .then(() => setToolStatus(toolCallId, "completed")) + .catch(() => setToolStatus(toolCallId, "error")) + .finally(() => { + refetchSessionForNewToolMessage(sessionId, toolCallId).then(() => { + clearTool(toolCallId); + }); + }); + }, + [ + setToolStatus, + addToolResult, + addToolLog, + addToolDebug, + addToolError, + clearTool, + refetchSessionForNewToolMessage, + ], + ); + + /** Cancel tool execution */ + const cancelToolExecution = useCallback( + (sessionId: string, toolCallId: string) => { + const activeStream = activeToolStreams[toolCallId]; + if (activeStream) { + activeStream.close(); + refetchSessionForNewToolMessage(sessionId, toolCallId).then(() => { + clearTool(toolCallId); + }); + } + }, + [activeToolStreams, clearTool, refetchSessionForNewToolMessage], + ); + + return { + startStream, + streamedChats, + streamedTools, + startToolExecution, + cancelToolExecution, + }; +}; + +export const ChatStreamContext = createContext< + ReturnType +>( + //@ts-expect-error should be initialized + null, +); diff --git a/web/src/lib/context/tools.ts b/web/src/lib/context/tools.ts new file mode 100644 index 0000000..1a284e4 --- /dev/null +++ b/web/src/lib/context/tools.ts @@ -0,0 +1,31 @@ +import { useCallback, useContext } from "react"; + +import { ChatStreamContext } from "./streamManager"; + +/** Hook to stream tool executions */ +export const useStreamingTools = () => { + const { streamedTools, startToolExecution, cancelToolExecution } = + useContext(ChatStreamContext); + + /** Execute a tool with streaming */ + const onToolExecute = useCallback( + (messageId: string, sessionId: string, toolCallId: string) => { + startToolExecution(messageId, sessionId, toolCallId); + }, + [startToolExecution], + ); + + /** Cancel a tool execution */ + const onToolCancel = useCallback( + (sessionId: string, toolCallId: string) => { + cancelToolExecution(sessionId, toolCallId); + }, + [cancelToolExecution], + ); + + return { + streamedTools, + onToolExecute, + onToolCancel, + }; +}; diff --git a/web/src/routes/app/_appLayout/session/$sessionId.tsx b/web/src/routes/app/_appLayout/session/$sessionId.tsx index 910c16e..a290980 100644 --- a/web/src/routes/app/_appLayout/session/$sessionId.tsx +++ b/web/src/routes/app/_appLayout/session/$sessionId.tsx @@ -19,10 +19,7 @@ import { } from "@/lib/api/session"; import { useTools } from "@/lib/api/tool"; import type { components } from "@/lib/api/types"; -import { - useStreamingChats, - useStreamingTools, -} from "@/lib/context/StreamingContext"; +import { useStreamingChats, useStreamingTools } from "@/lib/context"; export const Route = createFileRoute("/app/_appLayout/session/$sessionId")({ component: RouteComponent, From 0ba475a55aa3496e2051fca556885d852484b87b Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 03:35:23 -0400 Subject: [PATCH 19/46] server: move db logic out of the stream writer --- server/src/api/chat.rs | 31 ++++++++-- server/src/provider.rs | 2 - server/src/stream/llm_writer.rs | 105 +++++++++----------------------- 3 files changed, 55 insertions(+), 83 deletions(-) diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index c7a0e1f..642bde3 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -19,8 +19,8 @@ use crate::{ auth::ChatRsUserId, db::{ models::{ - ChatRsMessageMeta, ChatRsMessageRole, ChatRsSessionMeta, NewChatRsMessage, - UpdateChatRsSession, + AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsSessionMeta, + NewChatRsMessage, UpdateChatRsSession, }, services::{ChatDbService, ProviderDbService, ToolDbService}, DbConnection, DbPool, @@ -169,12 +169,31 @@ pub async fn send_chat_stream( let provider_options = input.options.clone(); // Create the Redis stream, then spawn a task to stream and save the response - stream_writer.create().await?; + stream_writer.start().await?; tokio::spawn(async move { - let mut chat_db_service = ChatDbService::new(&mut db); - stream_writer - .process(stream, &mut chat_db_service, provider_id, provider_options) + let (text, tool_calls, usage, errors, cancelled) = stream_writer.process(stream).await; + let assistant_meta = AssistantMeta { + provider_id, + provider_options: Some(provider_options), + tool_calls, + usage, + errors, + partial: cancelled.then_some(true), + }; + let db_result = ChatDbService::new(&mut db) + .save_message(NewChatRsMessage { + session_id: &session_id, + role: ChatRsMessageRole::Assistant, + content: &text.unwrap_or_default(), + meta: ChatRsMessageMeta::new_assistant(assistant_meta), + }) .await; + if let Err(err) = db_result { + rocket::error!("Failed to save assistant message: {}", err); + } + if !cancelled { + stream_writer.end().await.ok(); + } }); Ok(format!("Stream started at /api/chat/{}/stream", session_id)) diff --git a/server/src/provider.rs b/server/src/provider.rs index 7df7869..a604b19 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -53,8 +53,6 @@ pub enum LlmError { DecryptionError, #[error("Redis error: {0}")] Redis(#[from] fred::error::Error), - #[error("Failed to save message: {0}")] - Database(#[from] diesel::result::Error), } /// A streaming chunk of data from the LLM provider diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 2b21fb9..041b8b8 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -9,13 +9,8 @@ use serde::Serialize; use uuid::Uuid; use crate::{ - db::{ - models::{ - AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsToolCall, NewChatRsMessage, - }, - services::ChatDbService, - }, - provider::{LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmUsage}, + db::models::ChatRsToolCall, + provider::{LlmApiStream, LlmError, LlmUsage}, stream::get_chat_stream_key, }; @@ -36,8 +31,6 @@ pub struct LlmStreamWriter { redis: fred::prelude::Pool, /// The key of the Redis stream. key: String, - /// The chat session ID associated with the stream. - session_id: Uuid, /// The current chunk of data being processed. current_chunk: ChunkState, /// Accumulated text response from the assistant. @@ -83,7 +76,6 @@ impl LlmStreamWriter { LlmStreamWriter { redis: redis.clone(), key: get_chat_stream_key(user_id, session_id), - session_id: session_id.to_owned(), current_chunk: ChunkState::default(), complete_text: None, tool_calls: None, @@ -99,7 +91,7 @@ impl LlmStreamWriter { } /// Create the Redis stream and write a `start` entry. - pub async fn create(&self) -> Result<(), fred::prelude::Error> { + pub async fn start(&self) -> Result<(), fred::prelude::Error> { let entry: HashMap = RedisStreamChunk::Start.into(); let pipeline = self.redis.next().pipeline(); let _: () = pipeline.xadd(&self.key, false, None, "*", entry).await?; @@ -115,14 +107,27 @@ impl LlmStreamWriter { self.redis.del(&self.key).await } + /// Add an `end` event to notify clients that the stream has ended, and then + /// delete the stream from Redis. + pub async fn end(&self) -> Result<(), fred::prelude::Error> { + let entry: HashMap = RedisStreamChunk::End.into(); + let pipeline = self.redis.next().pipeline(); + let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; + let _: () = pipeline.del(&self.key).await?; + pipeline.all().await + } + /// Process the incoming stream from the LLM provider, intermittently flushing - /// chunks to a Redis stream, and saving the final accumulated response to the database. + /// chunks to a Redis stream, and return the final accumulated response. pub async fn process( &mut self, mut stream: LlmApiStream, - db: &mut ChatDbService<'_>, - provider_id: i32, - provider_options: LlmApiProviderSharedOptions, + ) -> ( + Option, + Option>, + Option, + Option>, + bool, ) { let ping_handle = self.start_ping_task(); @@ -161,20 +166,17 @@ impl LlmStreamWriter { last_flush_time = Instant::now(); } } - ping_handle.abort(); - if let Err(e) = self - .save_to_db(db, provider_id, provider_options, cancelled) - .await - { - if !cancelled { - self.current_chunk.error = Some(e.to_string()); - self.flush_chunk().await.ok(); - } - } - if !cancelled { - self.finish().await.ok(); - } + + let complete_text = self.complete_text.take(); + let tool_calls = self.tool_calls.take(); + let usage = self.usage.take(); + let errors = self.errors.take().map(|e| { + e.into_iter() + .map(|e| e.to_string()) + .collect::>() + }); + (complete_text, tool_calls, usage, errors, cancelled) } fn process_text(&mut self, text: &str) { @@ -266,53 +268,6 @@ impl LlmStreamWriter { } } - /// Add an `end` event to notify clients that the stream has ended, and then - /// delete the stream from Redis. - async fn finish(&self) -> Result<(), fred::prelude::Error> { - let entry: HashMap = RedisStreamChunk::End.into(); - let pipeline = self.redis.next().pipeline(); - let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; - let _: () = pipeline.del(&self.key).await?; - pipeline.all().await - } - - /// Saves the accumulated response to the database - async fn save_to_db( - &mut self, - db_service: &mut ChatDbService<'_>, - provider_id: i32, - provider_options: LlmApiProviderSharedOptions, - cancelled: bool, - ) -> Result<(), LlmError> { - let complete_text = self.complete_text.take(); - let tool_calls = self.tool_calls.take(); - let usage = self.usage.take(); - let errors = self.errors.take().map(|e| { - e.into_iter() - .map(|e| e.to_string()) - .collect::>() - }); - - let assistant_meta = AssistantMeta { - provider_id, - provider_options: Some(provider_options), - tool_calls, - usage, - errors, - partial: cancelled.then_some(true), - }; - db_service - .save_message(NewChatRsMessage { - session_id: &self.session_id, - role: ChatRsMessageRole::Assistant, - content: &complete_text.unwrap_or_default(), - meta: ChatRsMessageMeta::new_assistant(assistant_meta), - }) - .await?; - - Ok(()) - } - /// Start task that pings the Redis stream every `PING_INTERVAL` seconds fn start_ping_task(&self) -> tokio::task::JoinHandle<()> { let redis = self.redis.clone(); From 706300aa9d6a18f025bd2aebf6d8a3bfa61c91df Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 03:54:50 -0400 Subject: [PATCH 20/46] server: add stream writer tests --- server/src/provider.rs | 2 +- server/src/stream/llm_writer.rs | 269 ++++++++++++++++++++++++++++++++ 2 files changed, 270 insertions(+), 1 deletion(-) diff --git a/server/src/provider.rs b/server/src/provider.rs index a604b19..ef74474 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -77,7 +77,7 @@ pub struct LlmUsage { pub type LlmApiStream = Pin> + Send>>; /// Shared configuration for LLM provider requests -#[derive(Clone, Debug, JsonSchema, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] pub struct LlmApiProviderSharedOptions { pub model: String, pub temperature: Option, diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 041b8b8..69abaa2 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -287,3 +287,272 @@ impl LlmStreamWriter { ping_handle } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::{lorem::LoremProvider, LlmApiProvider, LlmApiProviderSharedOptions}; + use fred::prelude::{Builder, ClientLike, Config, Pool}; + use std::time::Duration; + + async fn setup_redis_pool() -> Pool { + let config = + Config::from_url("redis://127.0.0.1:6379").unwrap_or_else(|_| Config::default()); + let pool = Builder::from_config(config) + .build_pool(2) + .expect("Failed to build Redis pool"); + pool.init().await.expect("Failed to connect to Redis"); + pool + } + + fn create_test_writer(redis: &Pool) -> LlmStreamWriter { + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + LlmStreamWriter::new(redis, &user_id, &session_id) + } + + #[tokio::test] + async fn test_stream_writer_basic_functionality() { + let redis = setup_redis_pool().await; + let mut writer = create_test_writer(&redis); + + // Create stream + assert!(writer.start().await.is_ok()); + assert!(writer.exists().await.unwrap()); + + // Create Lorem provider and get stream + let lorem = LoremProvider::new(); + let stream = lorem + .chat_stream(vec![], None, &LlmApiProviderSharedOptions::default()) + .await + .expect("Failed to create lorem stream"); + + // Process the stream + let (text, tool_calls, usage, errors, cancelled) = writer.process(stream).await; + + // Verify results + assert!(text.is_some()); + let text = text.unwrap(); + assert!(!text.is_empty()); + assert!(text.contains("Lorem ipsum")); + assert!(text.contains("dolor sit")); + + assert!(tool_calls.is_none()); + assert!(usage.is_none()); + assert!(errors.is_some()); // Lorem provider generates some test errors + assert!(!cancelled); + + // End stream + assert!(writer.end().await.is_ok()); + + // Stream should be deleted after end + assert!(!writer.exists().await.unwrap()); + } + + #[tokio::test] + async fn test_stream_writer_batching() { + let redis = setup_redis_pool().await; + let mut writer = create_test_writer(&redis); + + assert!(writer.start().await.is_ok()); + + // Create a custom stream with small chunks to test batching + let chunks = vec![ + "Hello", " ", "world", "!", " ", "This", " ", "is", " ", "a", " ", "test", + ]; + let chunk_stream = tokio_stream::iter(chunks.into_iter().map(|text| { + Ok(crate::provider::LlmStreamChunk { + text: Some(text.to_string()), + tool_calls: None, + usage: None, + }) + })); + + let stream: LlmApiStream = Box::pin(chunk_stream); + let (text, _, _, _, cancelled) = writer.process(stream).await; + + assert!(text.is_some()); + let text = text.unwrap(); + assert_eq!(text, "Hello world! This is a test"); + assert!(!cancelled); + + writer.end().await.ok(); + } + + #[tokio::test] + async fn test_stream_writer_error_handling() { + let redis = setup_redis_pool().await; + let mut writer = create_test_writer(&redis); + + assert!(writer.start().await.is_ok()); + + // Create a stream that produces an error + let error_stream = tokio_stream::iter(vec![ + Ok(crate::provider::LlmStreamChunk { + text: Some("Hello".to_string()), + tool_calls: None, + usage: None, + }), + Err(crate::provider::LlmError::LoremError("Test error")), + Ok(crate::provider::LlmStreamChunk { + text: Some(" World".to_string()), + tool_calls: None, + usage: None, + }), + ]); + + let stream: LlmApiStream = Box::pin(error_stream); + let (text, _, _, errors, cancelled) = writer.process(stream).await; + + assert!(text.is_some()); + let text = text.unwrap(); + assert_eq!(text, "Hello World"); + + assert!(errors.is_some()); + let errors = errors.unwrap(); + assert!(!errors.is_empty()); + assert!(errors.iter().any(|e| e.contains("Test error"))); + + assert!(!cancelled); + + writer.end().await.ok(); + } + + #[tokio::test] + async fn test_stream_writer_timeout() { + let redis = setup_redis_pool().await; + let mut writer = create_test_writer(&redis); + + assert!(writer.start().await.is_ok()); + + // Create a stream that hangs (never yields anything) + let hanging_stream = tokio_stream::pending::< + Result, + >(); + + let stream: LlmApiStream = Box::pin(hanging_stream); + + // This should timeout due to LLM_TIMEOUT + let start = std::time::Instant::now(); + let (text, _, _, errors, cancelled) = writer.process(stream).await; + let elapsed = start.elapsed(); + + // Should complete in roughly LLM_TIMEOUT duration + assert!(elapsed >= Duration::from_secs(19)); // Allow some margin + assert!(elapsed < Duration::from_secs(25)); + + assert!(text.is_none()); + assert!(errors.is_some()); + let errors = errors.unwrap(); + assert!(errors.iter().any(|e| e.contains("Timeout"))); + assert!(!cancelled); // Timeout is not considered a cancellation + + writer.end().await.ok(); + } + + #[tokio::test] + async fn test_stream_writer_cancel() { + let redis = setup_redis_pool().await; + let writer = create_test_writer(&redis); + + assert!(writer.start().await.is_ok()); + assert!(writer.exists().await.unwrap()); + + // Cancel the stream + assert!(writer.cancel().await.is_ok()); + + // Stream should be deleted after cancel + assert!(!writer.exists().await.unwrap()); + } + + #[tokio::test] + async fn test_stream_writer_usage_tracking() { + let redis = setup_redis_pool().await; + let mut writer = create_test_writer(&redis); + + assert!(writer.start().await.is_ok()); + + // Create a stream with usage information + let usage_stream = tokio_stream::iter(vec![ + Ok(crate::provider::LlmStreamChunk { + text: Some("Hello".to_string()), + tool_calls: None, + usage: Some(crate::provider::LlmUsage { + input_tokens: Some(10), + output_tokens: Some(5), + cost: Some(0.001), + }), + }), + Ok(crate::provider::LlmStreamChunk { + text: Some(" World".to_string()), + tool_calls: None, + usage: Some(crate::provider::LlmUsage { + input_tokens: None, // Should not override + output_tokens: Some(7), // Should update + cost: Some(0.002), // Should update + }), + }), + ]); + + let stream: LlmApiStream = Box::pin(usage_stream); + let (text, _, usage, _, cancelled) = writer.process(stream).await; + + assert!(text.is_some()); + assert_eq!(text.unwrap(), "Hello World"); + + assert!(usage.is_some()); + let usage = usage.unwrap(); + assert_eq!(usage.input_tokens, Some(10)); + assert_eq!(usage.output_tokens, Some(7)); + assert_eq!(usage.cost, Some(0.002)); + + assert!(!cancelled); + + writer.end().await.ok(); + } + + #[tokio::test] + async fn test_redis_stream_entries() { + let redis = setup_redis_pool().await; + let mut writer = create_test_writer(&redis); + let key = writer.key.clone(); + + assert!(writer.start().await.is_ok()); + + // Verify start event was written + let entries: Vec<(String, HashMap)> = redis + .xrange(&key, "-", "+", None) + .await + .expect("Failed to read stream"); + + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].1.get("type"), Some(&"start".to_string())); + + // Create a simple stream + let simple_stream = tokio_stream::iter(vec![Ok(crate::provider::LlmStreamChunk { + text: Some("Test chunk".to_string()), + tool_calls: None, + usage: None, + })]); + + let stream: LlmApiStream = Box::pin(simple_stream); + writer.process(stream).await; + writer.flush_chunk().await.ok(); + + // Should have start + text entries (ping task may add more) + let final_entries: Vec<(String, HashMap)> = redis + .xrange(&key, "-", "+", None) + .await + .expect("Failed to read stream"); + + assert!(final_entries.len() >= 2); + + // Check that we have at least a text entry + let has_text = final_entries + .iter() + .any(|(_, data)| data.get("type") == Some(&"text".to_string())); + assert!(has_text); + + writer.end().await.ok(); + } +} From 4403fbb5518f9b7ea01bfce50a5faef825cce4a1 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 04:32:18 -0400 Subject: [PATCH 21/46] server: SSE spec fix --- server/src/stream/reader.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index 809d7bb..51af49b 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -139,7 +139,7 @@ fn convert_redis_event_to_sse((id, event): (String, HashMap)) -> for (key, value) in event { match key.as_str() { "type" => r#type = Some(value), - "data" => data = Some(value), + "data" => data = Some(format!(" {value}")), // SSE spec: add space before data _ => {} } } From 461706f439c45609a86aee1299d5b8e3645bbef8 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 04:57:30 -0400 Subject: [PATCH 22/46] server: add support for SSE Last-Event-ID header --- server/src/api/chat.rs | 8 +++++--- server/src/stream.rs | 30 ++++++++++++++++++++++++++++++ server/src/stream/reader.rs | 6 ++++-- 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index 642bde3..dc1219f 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -27,7 +27,7 @@ use crate::{ }, errors::ApiError, provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmError}, - stream::{LlmStreamWriter, SseStreamReader}, + stream::{LastEventId, LlmStreamWriter, SseStreamReader}, tools::{get_llm_tools_from_input, SendChatToolInput}, utils::{generate_title, Encryptor}, }; @@ -207,12 +207,14 @@ pub async fn connect_to_chat_stream( user_id: ChatRsUserId, redis: &State, session_id: Uuid, + start_event_id: Option, ) -> Result + Send>>>, ApiError> { let stream_reader = SseStreamReader::new(&redis); // Get all previous events from the Redis stream, and return them if we're already at the end of the stream - let (prev_events, last_event_id, is_end) = - stream_reader.get_prev_events(&user_id, &session_id).await?; + let (prev_events, last_event_id, is_end) = stream_reader + .get_prev_events(&user_id, &session_id, start_event_id.as_deref()) + .await?; let prev_events_stream = stream::iter(prev_events); if is_end { return Ok(EventStream::from(prev_events_stream.boxed())); diff --git a/server/src/stream.rs b/server/src/stream.rs index 6881fc9..ca278a3 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -4,6 +4,13 @@ mod reader; pub use llm_writer::*; pub use reader::*; +use rocket::{ + async_trait, + http::Status, + request::{FromRequest, Outcome}, + Request, +}; +use rocket_okapi::OpenApiFromRequest; use uuid::Uuid; /// Get the key prefix for the user's chat streams in Redis @@ -15,3 +22,26 @@ fn get_chat_stream_prefix(user_id: &Uuid) -> String { fn get_chat_stream_key(user_id: &Uuid, session_id: &Uuid) -> String { format!("{}:{}", get_chat_stream_prefix(user_id), session_id) } + +/// Request guard to extract the Last-Event-ID from the request headers +#[derive(OpenApiFromRequest)] +pub struct LastEventId(String); + +impl std::ops::Deref for LastEventId { + type Target = str; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[async_trait] +impl<'r> FromRequest<'r> for LastEventId { + type Error = (); + + async fn from_request(req: &'r Request<'_>) -> Outcome { + match req.headers().get_one("Last-Event-ID") { + Some(event_id) => Outcome::Success(LastEventId(event_id.to_owned())), + None => Outcome::Error((Status::BadRequest, ())), + } + } +} diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index 51af49b..7c44394 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -49,18 +49,20 @@ impl SseStreamReader { &self, user_id: &Uuid, session_id: &Uuid, + start_event_id: Option<&str>, ) -> Result<(Vec, String, bool), LlmError> { let key = get_chat_stream_key(user_id, session_id); + let start_event_id = start_event_id.unwrap_or("0-0"); let (_, prev_events): (String, Vec<(String, HashMap)>) = self .redis - .xread::>, _, _>(None, None, &key, "0-0") + .xread::>, _, _>(None, None, &key, start_event_id) .await? .and_then(|mut streams| streams.pop()) // should only be 1 stream since we're sending 1 key in the command .ok_or(LlmError::StreamNotFound)?; let (last_event_id, is_end) = prev_events .last() .map(|(id, data)| (id.to_owned(), data.get("type").is_some_and(|t| t == "end"))) - .unwrap_or_else(|| ("0-0".into(), false)); + .unwrap_or_else(|| (start_event_id.into(), false)); let sse_events = prev_events .into_iter() .map(convert_redis_event_to_sse) From 162263b4cb7c86f7d9fee1bbe4953997d787738d Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 23 Aug 2025 19:51:06 -0400 Subject: [PATCH 23/46] web: new streaming working but needs work --- server/src/api/chat.rs | 13 +- web/package.json | 1 + web/pnpm-lock.yaml | 9 + web/src/components/Sidebar.tsx | 6 +- .../components/chat/ChatStreamingMessages.tsx | 47 +-- web/src/hooks/useChatInputState.tsx | 4 +- web/src/lib/api/chat.ts | 130 +++--- web/src/lib/api/types.d.ts | 268 ++++++++++++- web/src/lib/context/chats.ts | 52 +-- web/src/lib/context/index.ts | 7 +- web/src/lib/context/streamManager.ts | 376 +++++------------- web/src/lib/context/streamManagerData.ts | 110 +++++ web/src/lib/context/streamManagerState.ts | 204 ++++++++++ 13 files changed, 827 insertions(+), 400 deletions(-) create mode 100644 web/src/lib/context/streamManagerData.ts create mode 100644 web/src/lib/context/streamManagerState.ts diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index dc1219f..40e4526 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -71,6 +71,12 @@ pub struct SendChatInput<'a> { tools: Option, } +#[derive(JsonSchema, serde::Serialize)] +pub struct SendChatResponse { + message: &'static str, + url: String, +} + /// # Start chat stream /// Send a chat message and start the streamed assistant response. After the response /// has started, use the `//stream` endpoint to connect to the SSE stream. @@ -85,7 +91,7 @@ pub async fn send_chat_stream( http_client: &State, session_id: Uuid, mut input: Json>, -) -> Result { +) -> Result, ApiError> { let mut stream_writer = LlmStreamWriter::new(&redis, &user_id, &session_id); // Check that we aren't already streaming a response for this session @@ -196,7 +202,10 @@ pub async fn send_chat_stream( } }); - Ok(format!("Stream started at /api/chat/{}/stream", session_id)) + Ok(Json(SendChatResponse { + message: "Stream started", + url: format!("/api/chat/{}/stream", session_id), + })) } /// # Connect to chat stream diff --git a/web/package.json b/web/package.json index 9c710f0..bfffb6e 100644 --- a/web/package.json +++ b/web/package.json @@ -35,6 +35,7 @@ "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "cmdk": "^1.1.1", + "eventsource-parser": "^3.0.5", "highlight.svelte": "^0.1.3", "lowlight": "^3.3.0", "lucide-react": "^0.476.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index ff7e221..0e4906b 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -65,6 +65,9 @@ importers: cmdk: specifier: ^1.1.1 version: 1.1.1(@types/react-dom@19.1.6(@types/react@19.1.8))(@types/react@19.1.8)(react-dom@19.1.0(react@19.1.0))(react@19.1.0) + eventsource-parser: + specifier: ^3.0.5 + version: 3.0.5 highlight.svelte: specifier: ^0.1.3 version: 0.1.3 @@ -1613,6 +1616,10 @@ packages: estree-walker@3.0.3: resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} + eventsource-parser@3.0.5: + resolution: {integrity: sha512-bSRG85ZrMdmWtm7qkF9He9TNRzc/Bm99gEJMaQoHJ9E6Kv9QBbsldh2oMj7iXmYNEAVvNgvv5vPorG6W+XtBhQ==} + engines: {node: '>=20.0.0'} + expect-type@1.2.1: resolution: {integrity: sha512-/kP8CAwxzLVEeFrMm4kMmy4CCDlpipyA7MYLVrdJIkV0fYF0UaigQHRsxHiuY/GEea+bh4KSv3TIlgr+2UL6bw==} engines: {node: '>=12.0.0'} @@ -3965,6 +3972,8 @@ snapshots: dependencies: '@types/estree': 1.0.8 + eventsource-parser@3.0.5: {} + expect-type@1.2.1: {} extend@3.0.2: {} diff --git a/web/src/components/Sidebar.tsx b/web/src/components/Sidebar.tsx index 77ebb43..ed606ab 100644 --- a/web/src/components/Sidebar.tsx +++ b/web/src/components/Sidebar.tsx @@ -39,7 +39,7 @@ import { } from "@/components/ui/sidebar"; import { useCreateChatSession } from "@/lib/api/session"; import type { components } from "@/lib/api/types"; -import type { StreamedChat } from "@/lib/context"; +import { useStreamingChats } from "@/lib/context"; import { Avatar, AvatarFallback, AvatarImage } from "./ui/avatar"; import { Button } from "./ui/button"; import { @@ -53,18 +53,18 @@ import { export function AppSidebar({ sessions, user, - streamedChats, ...props }: { sessions?: components["schemas"]["ChatRsSession"][]; user?: components["schemas"]["ChatRsUser"]; - streamedChats?: Record; } & React.ComponentProps) { const location = useLocation(); const navigate = useNavigate(); const router = useRouter(); const queryClient = useQueryClient(); + const { streamedChats } = useStreamingChats(); + const { mutate: createChatSession, isPending: createChatPending } = useCreateChatSession(); const onCreateChat = () => { diff --git a/web/src/components/chat/ChatStreamingMessages.tsx b/web/src/components/chat/ChatStreamingMessages.tsx index d536999..ec405ec 100644 --- a/web/src/components/chat/ChatStreamingMessages.tsx +++ b/web/src/components/chat/ChatStreamingMessages.tsx @@ -3,7 +3,7 @@ import { useEffect } from "react"; import Markdown from "react-markdown"; import useSmoothStreaming from "@/hooks/useSmoothStreaming"; -import type { StreamedChat } from "@/lib/context"; +import type { StreamingChat } from "@/lib/context"; import { cn } from "@/lib/utils"; import { ChatBubble, @@ -14,7 +14,7 @@ import { proseAssistantClasses, proseClasses } from "./messages/proseStyles"; interface Props { sessionId: string; - currentStream?: StreamedChat; + currentStream?: StreamingChat; } /** Displays currently streaming assistant responses and errors */ @@ -26,29 +26,28 @@ export default function ChatStreamingMessages({ displayedText: streamingMessage, complete, reset, - } = useSmoothStreaming(currentStream?.content); + } = useSmoothStreaming(currentStream?.text); useEffect(() => { if (currentStream?.status === "completed") complete(); }, [currentStream?.status, complete]); useEffect(() => { - if (!currentStream?.content) reset(); - }, [currentStream?.content, reset]); + if (!currentStream?.text) reset(); + }, [currentStream?.text, reset]); useEffect(() => { if (sessionId) reset(); }, [sessionId, reset]); return ( <> - {currentStream?.status === "streaming" && - currentStream?.content === "" && ( - - } - className="animate-pulse" - /> - - - )} + {currentStream?.status === "streaming" && currentStream?.text === "" && ( + + } + className="animate-pulse" + /> + + + )} {streamingMessage && ( @@ -68,14 +67,16 @@ export default function ChatStreamingMessages({ )} - {currentStream?.error && ( - - } /> - - {currentStream.error} - - - )} + {currentStream && + currentStream.errors.length > 0 && + currentStream.errors.map((error, idx) => ( + + } /> + + {error} + + + ))} ); } diff --git a/web/src/hooks/useChatInputState.tsx b/web/src/hooks/useChatInputState.tsx index 9f5ca25..ecfae75 100644 --- a/web/src/hooks/useChatInputState.tsx +++ b/web/src/hooks/useChatInputState.tsx @@ -137,7 +137,7 @@ export const useChatInputState = ({ onSubmit({ message: inputRef.current?.value, provider_id: providerId, - provider_options: { + options: { model: modelId, temperature, max_tokens: maxTokens, @@ -162,7 +162,7 @@ export const useChatInputState = ({ } onSubmit({ provider_id: providerId, - provider_options: { + options: { model: modelId, temperature, max_tokens: maxTokens, diff --git a/web/src/lib/api/chat.ts b/web/src/lib/api/chat.ts index f57ad53..ec585d4 100644 --- a/web/src/lib/api/chat.ts +++ b/web/src/lib/api/chat.ts @@ -1,77 +1,85 @@ -import { type ReadyStateEvent, SSE, type SSEvent } from "sse.js"; +import { useQuery } from "@tanstack/react-query"; +import { EventSourceParserStream } from "eventsource-parser/stream"; -import type { components } from "./types"; +import { client } from "./client"; -/** Stream a chat via SSE, using the `eventsource` library */ -export function streamChat( +async function getCurrentStreams() { + const res = await client.GET("/chat/streams"); + if (res.error) { + throw new Error(res.error.message); + } + return res.data; +} + +export const useGetCurrentStreams = (enabled: boolean) => + useQuery({ + enabled, + queryKey: ["serverStreams"], + queryFn: getCurrentStreams, + }); + +export async function createChatStream( sessionId: string, - input: components["schemas"]["SendChatInput"], { - onPart, + onText, + onToolCall, onError, }: { - onPart: (part: string) => void; + onText: (part: string) => void; + onToolCall: (toolCall: string) => void; onError: (error: string) => void; }, ) { - const source = new SSE(`/api/chat/${sessionId}`, { - method: "POST", - payload: JSON.stringify(input), - headers: { "Content-Type": "application/json" }, + const abortController = new AbortController(); + const res = await client.GET("/chat/{session_id}/stream", { + params: { path: { session_id: sessionId } }, + parseAs: "stream", + signal: abortController.signal, }); + if (res.error) { + throw new Error(res.error.message); + } + if (!res.data) { + throw new Error("No data received"); + } return { - get readyState() { - return source.readyState; - }, - start: new Promise((resolve, reject) => { - const chatListener = (event: SSEvent) => { - onPart(event.data); - }; - const errorListener = (event: SSEvent & { responseCode?: number }) => { - console.error("Error while streaming:", event); - if (event.responseCode) { - let data: string | undefined; - try { - data = JSON.parse(event.data).message; - } catch { - data = event.data; - } - if (typeof data === "string") { - onError(data); - } else { - switch (event.responseCode) { - case 404: - onError("Not Found Error"); - break; - case 500: - onError("Internal Server Error"); - break; - default: - onError(`Error code ${event.responseCode}`); - break; - } - } - reject(); - } else { - onError( - typeof event.data === "string" ? event.data : "Unknown error", - ); - } - }; + stream: async () => { + if (!res.data) return; + const eventStream = res.data + .pipeThrough(new TextDecoderStream()) + .pipeThrough(new EventSourceParserStream()) + .getReader(); + loop: while (true) { + const { done, value } = await eventStream.read(); + if (done) break; - const endListener = (event: ReadyStateEvent) => { - if (event.readyState === SSE.CLOSED) { - source.removeEventListener("chat", chatListener); - source.removeEventListener("error", errorListener); - source.removeEventListener("readystatechange", endListener); - resolve(); + switch (value.event) { + case "text": + onText(value.data); + break; + case "error": + onError(value.data); + break; + case "tool_call": + onToolCall(value.data); + break; + case "start": + case "ping": + break; + case "end": + case "cancel": + break loop; + default: + console.warn(`Unknown event type: ${value.event}`); + break; } - }; - - source.addEventListener("chat", chatListener); - source.addEventListener("error", errorListener); - source.addEventListener("readystatechange", endListener); - }), + } + try { + abortController.abort(); + } catch (error) { + console.warn("Error closing event stream:", error); + } + }, }; } diff --git a/web/src/lib/api/types.d.ts b/web/src/lib/api/types.d.ts index 46348d5..fb49bfd 100644 --- a/web/src/lib/api/types.d.ts +++ b/web/src/lib/api/types.d.ts @@ -237,6 +237,26 @@ export interface paths { patch?: never; trace?: never; }; + "/chat/streams": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get chat streams + * @description Get the ongoing chat response streams + */ + get: operations["get_chat_streams"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/chat/{session_id}": { parameters: { query?: never; @@ -246,7 +266,10 @@ export interface paths { }; get?: never; put?: never; - /** @description Send a chat message and stream the response */ + /** + * Start chat stream + * @description Send a chat message and start the streamed assistant response. After the response has started, use the `//stream` endpoint to connect to the SSE stream. + */ post: operations["send_chat_stream"]; delete?: never; options?: never; @@ -254,6 +277,46 @@ export interface paths { patch?: never; trace?: never; }; + "/chat/{session_id}/stream": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Connect to chat stream + * @description Connect to an ongoing chat stream and stream the assistant response + */ + get: operations["connect_to_chat_stream"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/chat/{session_id}/cancel": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Cancel chat stream + * @description Cancel an ongoing chat stream + */ + post: operations["cancel_chat_stream"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/tool/": { parameters: { query?: never; @@ -598,6 +661,8 @@ export interface components { tool_calls?: components["schemas"]["ChatRsToolCall"][] | null; /** @description Provider usage information */ usage?: components["schemas"]["LlmUsage"] | null; + /** @description Errors encountered during message generation */ + errors?: string[] | null; /** @description Whether this is a partial and/or interrupted message */ partial?: boolean | null; }; @@ -691,6 +756,9 @@ export interface components { UpdateSessionInput: { title: string; }; + GetChatStreamsResponse: { + sessions: string[]; + }; SendChatInput: { /** @description The new chat message from the user */ message?: string | null; @@ -699,8 +767,8 @@ export interface components { * @description The ID of the provider to chat with */ provider_id: number; - /** @description Provider options */ - provider_options: components["schemas"]["LlmApiProviderSharedOptions"]; + /** @description Configuration for the provider */ + options: components["schemas"]["LlmApiProviderSharedOptions"]; /** @description Configuration of tools available to the assistant */ tools?: components["schemas"]["SendChatToolInput"] | null; }; @@ -1945,6 +2013,70 @@ export interface operations { }; }; }; + get_chat_streams: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["GetChatStreamsResponse"]; + }; + }; + /** @description Bad request */ + 400: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Authentication error */ + 401: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Not found */ + 404: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Incorrectly formatted */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Internal error */ + 500: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + }; + }; send_chat_stream: { parameters: { query?: never; @@ -1959,6 +2091,72 @@ export interface operations { "application/json": components["schemas"]["SendChatInput"]; }; }; + responses: { + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "text/plain": string; + }; + }; + /** @description Bad request */ + 400: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Authentication error */ + 401: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Not found */ + 404: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Incorrectly formatted */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Internal error */ + 500: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + }; + }; + connect_to_chat_stream: { + parameters: { + query?: never; + header?: never; + path: { + session_id: string; + }; + cookie?: never; + }; + requestBody?: never; responses: { 200: { headers: { @@ -2015,6 +2213,70 @@ export interface operations { }; }; }; + cancel_chat_stream: { + parameters: { + query?: never; + header?: never; + path: { + session_id: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + 200: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Bad request */ + 400: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Authentication error */ + 401: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Not found */ + 404: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Incorrectly formatted */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + /** @description Internal error */ + 500: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["Message"]; + }; + }; + }; + }; get_all_tools: { parameters: { query?: never; diff --git a/web/src/lib/context/chats.ts b/web/src/lib/context/chats.ts index 6879caa..b787a9c 100644 --- a/web/src/lib/context/chats.ts +++ b/web/src/lib/context/chats.ts @@ -7,35 +7,39 @@ import { ChatStreamContext } from "./streamManager"; /** Hook to stream a chat, get chat stream status, etc. */ export const useStreamingChats = () => { const queryClient = useQueryClient(); - const { streamedChats, startStream } = useContext(ChatStreamContext); + const { streamedChats, startStreamWithInput } = useContext(ChatStreamContext); /** Start stream + optimistic update of user message */ const onUserSubmit = useCallback( - (sessionId: string, input: components["schemas"]["SendChatInput"]) => { - startStream(sessionId, input); - if (!input.message) return; + async ( + sessionId: string, + input: components["schemas"]["SendChatInput"], + ) => { + startStreamWithInput(sessionId, input); - queryClient.setQueryData<{ - messages: components["schemas"]["ChatRsMessage"][]; - }>(["chatSession", { sessionId }], (oldData: any) => { - if (!oldData) return {}; - return { - ...oldData, - messages: [ - ...oldData.messages, - { - id: crypto.randomUUID(), - content: input.message, - role: "User", - created_at: new Date().toISOString(), - session_id: sessionId, - meta: {}, - }, - ], - }; - }); + if (input.message) { + queryClient.setQueryData<{ + messages: components["schemas"]["ChatRsMessage"][]; + }>(["chatSession", { sessionId }], (oldData: any) => { + if (!oldData) return {}; + return { + ...oldData, + messages: [ + ...oldData.messages, + { + id: crypto.randomUUID(), + content: input.message, + role: "User", + created_at: new Date().toISOString(), + session_id: sessionId, + meta: {}, + }, + ], + }; + }); + } }, - [startStream, queryClient], + [startStreamWithInput, queryClient], ); return { diff --git a/web/src/lib/context/index.ts b/web/src/lib/context/index.ts index 79f59bd..d39d342 100644 --- a/web/src/lib/context/index.ts +++ b/web/src/lib/context/index.ts @@ -1,10 +1,13 @@ import { useStreamingChats } from "./chats"; -import type { StreamedChat, StreamedToolExecution } from "./streamManager"; +import type { + StreamedToolExecution, + StreamingChat, +} from "./streamManagerState"; import { useStreamingTools } from "./tools"; export { useStreamingTools, useStreamingChats, - type StreamedChat, + type StreamingChat, type StreamedToolExecution, }; diff --git a/web/src/lib/context/streamManager.ts b/web/src/lib/context/streamManager.ts index 95e5b1e..b7c5223 100644 --- a/web/src/lib/context/streamManager.ts +++ b/web/src/lib/context/streamManager.ts @@ -1,289 +1,106 @@ -import { useQueryClient } from "@tanstack/react-query"; -import { createContext, useCallback, useState } from "react"; - -import { streamChat } from "../api/chat"; -import { chatSessionQueryKey, recentSessionsQueryKey } from "../api/session"; -import { streamToolExecution } from "../api/tool"; -import type { components } from "../api/types"; - -export interface StreamedChat { - content: string; - error?: string; - status: "streaming" | "completed"; -} - -export interface StreamedToolExecution { - result: string; - logs: string[]; - debugLogs: string[]; - error?: string; - status: "streaming" | "completed" | "error"; -} - -const streamedToolExecutionInit = (): StreamedToolExecution => ({ - result: "", - logs: [], - debugLogs: [], - status: "streaming", -}); - -/** Manage ongoing chat streams and tool executions */ -export const useStreamManager = () => { - const [streamedChats, setStreamedChats] = useState<{ - [sessionId: string]: StreamedChat | undefined; - }>({}); - - const [streamedTools, setStreamedTools] = useState<{ - [toolCallId: string]: StreamedToolExecution | undefined; - }>({}); - - const [activeToolStreams, setActiveToolStreams] = useState<{ - [toolCallId: string]: { close: () => void } | undefined; - }>({}); - - const addChatPart = useCallback((sessionId: string, part: string) => { - setStreamedChats((prev) => ({ - ...prev, - [sessionId]: { - content: (prev?.[sessionId]?.content || "") + part, - error: prev?.[sessionId]?.error, - status: "streaming", - }, - })); - }, []); - - const addChatError = useCallback((sessionId: string, error: string) => { - setStreamedChats((prev) => ({ - ...prev, - [sessionId]: { - content: prev?.[sessionId]?.content || "", - status: "streaming", - error, - }, - })); - }, []); - - const setChatStatus = useCallback( - (sessionId: string, status: "streaming" | "completed") => { - setStreamedChats((prev) => ({ - ...prev, - [sessionId]: { - status, - content: prev?.[sessionId]?.content || "", - error: prev?.[sessionId]?.error, - }, - })); - }, - [], - ); - - const clearChat = useCallback((sessionId: string) => { - setStreamedChats((prev) => ({ - ...prev, - [sessionId]: undefined, - })); - }, []); - - const queryClient = useQueryClient(); - - const invalidateSession = useCallback( - async (sessionId: string) => { - await Promise.allSettled([ - queryClient.invalidateQueries({ - queryKey: chatSessionQueryKey(sessionId), - }), - queryClient.invalidateQueries({ - queryKey: recentSessionsQueryKey, - }), - ]); - }, - [queryClient], - ); - - /** Refetch chat session for the new assistant message */ - const refetchSessionForNewAssistantResponse = useCallback( - async (sessionId: string) => { - const retryDelay = 1000; // 1 second - try { - // Refetch chat session with retry loop - let hasNewAssistantMessage = false; - let retryCount = 0; - const maxRetries = 3; - - while (!hasNewAssistantMessage && retryCount < maxRetries) { - await invalidateSession(sessionId); - - // Check if the chat session has been updated with the new assistant response - const updatedData = queryClient.getQueryData<{ - messages: components["schemas"]["ChatRsMessage"][]; - }>(["chatSession", { sessionId }]); - hasNewAssistantMessage = - updatedData?.messages?.some( - (msg) => - msg.role === "Assistant" && - !msg.meta.assistant?.partial && - new Date(msg.created_at).getTime() > Date.now() - 5000, // Within last 5 seconds - ) || false; - - // Retry if no new assistant message - if (!hasNewAssistantMessage) { - retryCount++; - if (retryCount < maxRetries) { - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - } - } - } - } catch (error) { - console.error("Error refetching chat session:", error); - await invalidateSession(sessionId); - } - }, - [invalidateSession, queryClient], - ); - - /** Refetch chat session for the new tool message */ - const refetchSessionForNewToolMessage = useCallback( - async (sessionId: string, toolCallId: string) => { - const retryDelay = 1000; // 1 second - try { - let hasNewToolMessage = false; - let retryCount = 0; - const maxRetries = 3; - - while (!hasNewToolMessage && retryCount < maxRetries) { - await invalidateSession(sessionId); - - const updatedData = queryClient.getQueryData<{ - messages: components["schemas"]["ChatRsMessage"][]; - }>(["chatSession", { sessionId }]); - hasNewToolMessage = - updatedData?.messages?.some( - (msg) => - msg.role === "Tool" && msg.meta.tool_call?.id === toolCallId, - ) || false; - - if (!hasNewToolMessage) { - retryCount++; - if (retryCount < maxRetries) { - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - } - } - } - } catch (error) { - console.error("Error refetching chat session:", error); - await invalidateSession(sessionId); - } - }, - [invalidateSession, queryClient], - ); +import { createContext, useCallback, useEffect } from "react"; + +import { createChatStream } from "@/lib/api/chat"; +import { client } from "@/lib/api/client"; +import { streamToolExecution } from "@/lib/api/tool"; +import type { components } from "@/lib/api/types"; +import { useStreamManagerData } from "./streamManagerData"; +import { useStreamManagerState } from "./streamManagerState"; + +export function useStreamManager() { + const { + currentChatStreams, + initSession, + clearSession, + setSessionCompleted, + streamedTools, + addTextChunk, + addToolCallChunk, + addErrorChunk, + activeToolStreams, + addActiveToolStream, + addToolLog, + addToolDebug, + addToolResult, + addToolError, + clearTool, + setToolStatus, + } = useStreamManagerState(); + + const { + refetchSessionForNewAssistantMessage, + refetchSessionForNewToolMessage, + serverStreams, + } = useStreamManagerData(); - /** Start a new chat stream */ const startStream = useCallback( - (sessionId: string, input: components["schemas"]["SendChatInput"]) => { - clearChat(sessionId); - setChatStatus(sessionId, "streaming"); - const stream = streamChat(sessionId, input, { - onPart: (part) => { - addChatPart(sessionId, part); - }, - onError: (error) => { - addChatError(sessionId, error); - }, - }); - stream.start - .then(() => { - refetchSessionForNewAssistantResponse(sessionId).then(() => - clearChat(sessionId), - ); + async (sessionId: string) => { + clearSession(sessionId); + initSession(sessionId); + + createChatStream(sessionId, { + onText: (text) => addTextChunk(sessionId, text), + onToolCall: (toolCall) => addToolCallChunk(sessionId, toolCall), + onError: (error) => addErrorChunk(sessionId, error), + }) + .then((chatStream) => { + chatStream.stream().finally(() => { + setSessionCompleted(sessionId); + refetchSessionForNewAssistantMessage(sessionId) + .then(() => clearSession(sessionId)) + .catch((err: unknown) => { + console.error("Error refetching messages:", err); + addErrorChunk(sessionId, "Error refetching chat session."); + }); + }); }) - .catch(() => { - invalidateSession(sessionId).then(() => - setChatStatus(sessionId, "completed"), - ); + .catch((err: Error) => { + addErrorChunk(sessionId, `Error starting stream: ${err.message}`); + setSessionCompleted(sessionId); + console.error("Error starting stream:", err.message); }); }, [ - clearChat, - addChatPart, - addChatError, - setChatStatus, - invalidateSession, - refetchSessionForNewAssistantResponse, + clearSession, + setSessionCompleted, + addToolCallChunk, + addTextChunk, + addErrorChunk, + initSession, + refetchSessionForNewAssistantMessage, ], ); - /** Add tool execution result chunk */ - const addToolResult = useCallback((toolCallId: string, result: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - result: (prev?.[toolCallId]?.result || "") + result, - }, - })); - }, []); - - /** Add tool execution log */ - const addToolLog = useCallback((toolCallId: string, log: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - logs: [...(prev?.[toolCallId]?.logs || []), log], - }, - })); - }, []); - - /** Add tool execution debug log */ - const addToolDebug = useCallback((toolCallId: string, debug: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - debugLogs: [...(prev?.[toolCallId]?.debugLogs || []), debug], - }, - })); - }, []); - - /** Add tool execution error */ - const addToolError = useCallback((toolCallId: string, error: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - error, - status: "error", - }, - })); - }, []); - - /** Set tool execution status */ - const setToolStatus = useCallback( - (toolCallId: string, status: "streaming" | "completed" | "error") => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: { - ...(prev?.[toolCallId] || streamedToolExecutionInit()), - status, - }, - })); + // Automatically start any ongoing chat streams + useEffect(() => { + if (!serverStreams) return; + for (const sessionId of serverStreams.sessions) { + if (!currentChatStreams[sessionId]) { + startStream(sessionId); + } + } + }, [serverStreams, currentChatStreams, startStream]); + + const startStreamWithInput = useCallback( + async ( + sessionId: string, + input: components["schemas"]["SendChatInput"], + ) => { + initSession(sessionId); + const response = await client.POST("/chat/{session_id}", { + params: { path: { session_id: sessionId } }, + body: input, + }); + if (response.error) { + addErrorChunk(sessionId, response.error.message); + setSessionCompleted(sessionId); + return; + } + await startStream(sessionId); }, - [], + [initSession, addErrorChunk, startStream, setSessionCompleted], ); - /** Clear active tool execution */ - const clearTool = useCallback((toolCallId: string) => { - setStreamedTools((prev) => ({ - ...prev, - [toolCallId]: undefined, - })); - setActiveToolStreams((prev) => ({ - ...prev, - [toolCallId]: undefined, - })); - }, []); - - /** Start tool execution stream */ const startToolExecution = useCallback( (messageId: string, sessionId: string, toolCallId: string) => { const stream = streamToolExecution(messageId, toolCallId, { @@ -295,10 +112,7 @@ export const useStreamManager = () => { clearTool(toolCallId); setToolStatus(toolCallId, "streaming"); - setActiveToolStreams((prev) => ({ - ...prev, - [toolCallId]: { close: stream.close }, - })); + addActiveToolStream(toolCallId, stream.close); stream.start .then(() => setToolStatus(toolCallId, "completed")) @@ -315,6 +129,7 @@ export const useStreamManager = () => { addToolLog, addToolDebug, addToolError, + addActiveToolStream, clearTool, refetchSessionForNewToolMessage, ], @@ -336,12 +151,13 @@ export const useStreamManager = () => { return { startStream, - streamedChats, + startStreamWithInput, + streamedChats: currentChatStreams, streamedTools, startToolExecution, cancelToolExecution, }; -}; +} export const ChatStreamContext = createContext< ReturnType diff --git a/web/src/lib/context/streamManagerData.ts b/web/src/lib/context/streamManagerData.ts new file mode 100644 index 0000000..cb046f3 --- /dev/null +++ b/web/src/lib/context/streamManagerData.ts @@ -0,0 +1,110 @@ +import { useQueryClient } from "@tanstack/react-query"; +import { useCallback } from "react"; + +import { useGetCurrentStreams } from "../api/chat"; +import { chatSessionQueryKey, recentSessionsQueryKey } from "../api/session"; +import type { components } from "../api/types"; +import { useGetUser } from "../api/user"; + +export function useStreamManagerData() { + const queryClient = useQueryClient(); + const { data: user } = useGetUser(); + const { data: serverStreams } = useGetCurrentStreams(!!user); + + const invalidateSession = useCallback( + async (sessionId: string) => { + await Promise.all([ + queryClient.invalidateQueries({ + queryKey: chatSessionQueryKey(sessionId), + }), + queryClient.invalidateQueries({ + queryKey: recentSessionsQueryKey, + }), + queryClient.invalidateQueries({ + queryKey: ["serverStreams"], + }), + ]); + }, + [queryClient], + ); + + const refetchSessionForNewAssistantMessage = useCallback( + async (sessionId: string) => { + const retryDelay = 1000; // 1 second + try { + // Refetch chat session with retry loop + let hasNewAssistantMessage = false; + let retryCount = 0; + const maxRetries = 3; + + while (!hasNewAssistantMessage && retryCount < maxRetries) { + await invalidateSession(sessionId); + + // Check if the chat session has been updated with the new assistant response + const updatedData = queryClient.getQueryData<{ + messages: components["schemas"]["ChatRsMessage"][]; + }>(["chatSession", { sessionId }]); + hasNewAssistantMessage = + updatedData?.messages?.some( + (msg) => + msg.role === "Assistant" && + !msg.meta.assistant?.partial && + new Date(msg.created_at).getTime() > Date.now() - 5000, // Within last 5 seconds + ) || false; + + // Retry if no new assistant message + if (!hasNewAssistantMessage) { + retryCount++; + if (retryCount < maxRetries) { + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + } + } + } + } catch (error) { + console.error("Error refetching chat session:", error); + await invalidateSession(sessionId); + } + }, + [invalidateSession, queryClient], + ); + const refetchSessionForNewToolMessage = useCallback( + async (sessionId: string, toolCallId: string) => { + const retryDelay = 1000; // 1 second + try { + let hasNewToolMessage = false; + let retryCount = 0; + const maxRetries = 3; + + while (!hasNewToolMessage && retryCount < maxRetries) { + await invalidateSession(sessionId); + + const updatedData = queryClient.getQueryData<{ + messages: components["schemas"]["ChatRsMessage"][]; + }>(["chatSession", { sessionId }]); + hasNewToolMessage = + updatedData?.messages?.some( + (msg) => + msg.role === "Tool" && msg.meta.tool_call?.id === toolCallId, + ) || false; + + if (!hasNewToolMessage) { + retryCount++; + if (retryCount < maxRetries) { + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + } + } + } + } catch (error) { + console.error("Error refetching chat session:", error); + await invalidateSession(sessionId); + } + }, + [invalidateSession, queryClient], + ); + + return { + serverStreams, + refetchSessionForNewAssistantMessage, + refetchSessionForNewToolMessage, + }; +} diff --git a/web/src/lib/context/streamManagerState.ts b/web/src/lib/context/streamManagerState.ts new file mode 100644 index 0000000..45f865a --- /dev/null +++ b/web/src/lib/context/streamManagerState.ts @@ -0,0 +1,204 @@ +import { useCallback, useState } from "react"; + +import type { components } from "../api/types"; + +export interface StreamingChat { + text: string; + errors: string[]; + toolCalls: components["schemas"]["ChatRsToolCall"][]; + status: "streaming" | "completed"; +} +const initialChatState = (): StreamingChat => ({ + text: "", + errors: [], + toolCalls: [], + status: "streaming", +}); + +export interface StreamedToolExecution { + result: string; + logs: string[]; + debugLogs: string[]; + error?: string; + status: "streaming" | "completed" | "error"; +} +const initialToolState = (): StreamedToolExecution => ({ + result: "", + logs: [], + debugLogs: [], + status: "streaming", +}); + +export function useStreamManagerState() { + const [currentChatStreams, setCurrentChatStreams] = useState<{ + [sessionId: string]: StreamingChat | undefined; + }>({}); + + const [streamedTools, setStreamedTools] = useState<{ + [toolCallId: string]: StreamedToolExecution | undefined; + }>({}); + + const [activeToolStreams, setActiveToolStreams] = useState<{ + [toolCallId: string]: { close: () => void } | undefined; + }>({}); + + const initSession = useCallback((sessionId: string) => { + setCurrentChatStreams((prev) => ({ + ...prev, + [sessionId]: initialChatState(), + })); + }, []); + + const addTextChunk = useCallback((sessionId: string, text: string) => { + setCurrentChatStreams((prev) => ({ + ...prev, + [sessionId]: { + ...(prev[sessionId] || initialChatState()), + text: (prev[sessionId]?.text || "") + text, + }, + })); + }, []); + + const addErrorChunk = useCallback((sessionId: string, error: string) => { + setCurrentChatStreams((prev) => ({ + ...prev, + [sessionId]: { + ...(prev[sessionId] || initialChatState()), + errors: [...(prev[sessionId]?.errors || []), error], + }, + })); + }, []); + + const addToolCallChunk = useCallback( + (sessionId: string, toolCall: string) => { + setCurrentChatStreams((prev) => ({ + ...prev, + [sessionId]: { + ...(prev[sessionId] || initialChatState()), + toolCalls: [ + ...(prev[sessionId]?.toolCalls || []), + JSON.parse(toolCall), + ], + }, + })); + }, + [], + ); + + const setSessionCompleted = useCallback((sessionId: string) => { + setCurrentChatStreams((prev) => ({ + ...prev, + [sessionId]: { + ...(prev[sessionId] || initialChatState()), + status: "completed", + }, + })); + }, []); + + const clearSession = useCallback((sessionId: string) => { + setCurrentChatStreams((prev) => ({ + ...prev, + [sessionId]: undefined, + })); + }, []); + + const addActiveToolStream = useCallback( + (toolCallId: string, close: () => void) => { + setActiveToolStreams((prev) => ({ + ...prev, + [toolCallId]: { close }, + })); + }, + [], + ); + + /** Add tool execution result chunk */ + const addToolResult = useCallback((toolCallId: string, result: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || initialToolState()), + result: (prev?.[toolCallId]?.result || "") + result, + }, + })); + }, []); + + /** Add tool execution log */ + const addToolLog = useCallback((toolCallId: string, log: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || initialToolState()), + logs: [...(prev?.[toolCallId]?.logs || []), log], + }, + })); + }, []); + + /** Add tool execution debug log */ + const addToolDebug = useCallback((toolCallId: string, debug: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || initialToolState()), + debugLogs: [...(prev?.[toolCallId]?.debugLogs || []), debug], + }, + })); + }, []); + + /** Add tool execution error */ + const addToolError = useCallback((toolCallId: string, error: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || initialToolState()), + error, + status: "error", + }, + })); + }, []); + + /** Set tool execution status */ + const setToolStatus = useCallback( + (toolCallId: string, status: "streaming" | "completed" | "error") => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: { + ...(prev?.[toolCallId] || initialToolState()), + status, + }, + })); + }, + [], + ); + + /** Clear active tool execution */ + const clearTool = useCallback((toolCallId: string) => { + setStreamedTools((prev) => ({ + ...prev, + [toolCallId]: undefined, + })); + setActiveToolStreams((prev) => ({ + ...prev, + [toolCallId]: undefined, + })); + }, []); + + return { + currentChatStreams, + initSession, + addTextChunk, + addErrorChunk, + addToolCallChunk, + clearSession, + setSessionCompleted, + addToolLog, + addToolDebug, + addToolError, + addToolResult, + activeToolStreams, + addActiveToolStream, + clearTool, + setToolStatus, + streamedTools, + }; +} From d6f51ecceb02ed83de32303d0ef9e57226176c16 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 24 Aug 2025 00:09:37 -0400 Subject: [PATCH 24/46] server: enable exclusive Redis connection locks for reading/writing streams --- server/Cargo.lock | 4 + server/Cargo.toml | 1 + server/src/api/chat.rs | 33 ++++---- server/src/api/provider.rs | 4 +- server/src/config.rs | 2 + server/src/provider.rs | 2 +- server/src/provider/anthropic.rs | 8 +- server/src/provider/openai.rs | 4 +- server/src/provider_models.rs | 6 +- server/src/redis.rs | 140 ++++++++++++++++++++++++++++--- server/src/stream.rs | 53 +++++++++++- server/src/stream/llm_writer.rs | 118 ++++++++++++++++---------- server/src/stream/reader.rs | 32 ++----- 13 files changed, 299 insertions(+), 108 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index de142b9..56c0a57 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -377,6 +377,7 @@ dependencies = [ "bollard", "chrono", "const_format", + "deadpool", "diesel", "diesel-async", "diesel-derive-enum", @@ -624,6 +625,9 @@ name = "deadpool-runtime" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +dependencies = [ + "tokio", +] [[package]] name = "deranged" diff --git a/server/Cargo.toml b/server/Cargo.toml index 3c2bef7..121344e 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -22,6 +22,7 @@ async-stream = "0.3.6" bollard = { version = "0.19.1", features = ["ssl"] } chrono = { version = "0.4.41", features = ["serde"] } const_format = "0.2.34" +deadpool = { version = "0.12.2", features = ["rt_tokio_1"] } diesel = { version = "2.2.10", features = [ "postgres", "chrono", diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index 40e4526..2bcb95a 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -27,7 +27,11 @@ use crate::{ }, errors::ApiError, provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmError}, - stream::{LastEventId, LlmStreamWriter, SseStreamReader}, + redis::ExclusiveRedisClient, + stream::{ + cancel_current_chat_stream, check_chat_stream_exists, get_current_chat_streams, + LastEventId, LlmStreamWriter, SseStreamReader, + }, tools::{get_llm_tools_from_input, SendChatToolInput}, utils::{generate_title, Encryptor}, }; @@ -52,10 +56,9 @@ pub struct GetChatStreamsResponse { #[get("/streams")] pub async fn get_chat_streams( user_id: ChatRsUserId, - redis: &State, + redis_pool: &State, ) -> Result, ApiError> { - let stream_reader = SseStreamReader::new(&redis); - let sessions = stream_reader.get_chat_streams(&user_id).await?; + let sessions = get_current_chat_streams(redis_pool.next(), &user_id).await?; Ok(Json(GetChatStreamsResponse { sessions })) } @@ -86,16 +89,15 @@ pub async fn send_chat_stream( user_id: ChatRsUserId, db_pool: &State, mut db: DbConnection, - redis: &State, + redis_client: ExclusiveRedisClient, + redis_pool: &State, encryptor: &State, http_client: &State, session_id: Uuid, mut input: Json>, ) -> Result, ApiError> { - let mut stream_writer = LlmStreamWriter::new(&redis, &user_id, &session_id); - // Check that we aren't already streaming a response for this session - if stream_writer.exists().await? { + if check_chat_stream_exists(&redis_client, &user_id, &session_id).await? { return Err(LlmError::AlreadyStreaming)?; } @@ -116,7 +118,7 @@ pub async fn send_chat_stream( provider.base_url.as_deref(), api_key.as_deref(), &http_client, - redis, + redis_pool.next(), )?; // Get the user's chosen tools @@ -175,6 +177,7 @@ pub async fn send_chat_stream( let provider_options = input.options.clone(); // Create the Redis stream, then spawn a task to stream and save the response + let mut stream_writer = LlmStreamWriter::new(redis_client, &user_id, &session_id); stream_writer.start().await?; tokio::spawn(async move { let (text, tool_calls, usage, errors, cancelled) = stream_writer.process(stream).await; @@ -214,11 +217,11 @@ pub async fn send_chat_stream( #[get("//stream")] pub async fn connect_to_chat_stream( user_id: ChatRsUserId, - redis: &State, + redis_client: ExclusiveRedisClient, session_id: Uuid, start_event_id: Option, ) -> Result + Send>>>, ApiError> { - let stream_reader = SseStreamReader::new(&redis); + let stream_reader = SseStreamReader::new(redis_client); // Get all previous events from the Redis stream, and return them if we're already at the end of the stream let (prev_events, last_event_id, is_end) = stream_reader @@ -249,13 +252,13 @@ pub async fn connect_to_chat_stream( #[post("//cancel")] pub async fn cancel_chat_stream( user_id: ChatRsUserId, - redis: &State, + redis_pool: &State, session_id: Uuid, ) -> Result<(), ApiError> { - let stream_writer = LlmStreamWriter::new(&redis, &user_id, &session_id); - if !stream_writer.exists().await? { + let client = redis_pool.next(); + if !check_chat_stream_exists(&client, &user_id, &session_id).await? { return Err(LlmError::StreamNotFound)?; } - stream_writer.cancel().await?; + cancel_current_chat_stream(&client, &user_id, &session_id).await?; Ok(()) } diff --git a/server/src/api/provider.rs b/server/src/api/provider.rs index dc2989c..445c849 100644 --- a/server/src/api/provider.rs +++ b/server/src/api/provider.rs @@ -53,7 +53,7 @@ async fn get_all_providers( async fn list_models( user_id: ChatRsUserId, mut db: DbConnection, - redis: &State, + redis_pool: &State, encryptor: &State, http_client: &State, provider_id: i32, @@ -70,7 +70,7 @@ async fn list_models( provider.base_url.as_deref(), api_key.as_deref(), &http_client, - &redis, + redis_pool.next(), )?; Ok(Json(provider_api.list_models().await?)) diff --git a/server/src/config.rs b/server/src/config.rs index 1476ba8..680a5d6 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -22,6 +22,8 @@ pub struct AppConfig { pub redis_url: String, /// Redis pool size (default: 4) pub redis_pool: Option, + /// Maximum number of concurrent Redis connections for streaming (default: 20) + pub max_streams: Option, } /// Get the server configuration variables from Rocket diff --git a/server/src/provider.rs b/server/src/provider.rs index ef74474..e4ee7a4 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -132,7 +132,7 @@ pub fn build_llm_provider_api( base_url: Option<&str>, api_key: Option<&str>, http_client: &reqwest::Client, - redis: &fred::prelude::Pool, + redis: &fred::clients::Client, ) -> Result, LlmError> { match provider_type { ChatRsProviderType::Openai => Ok(Box::new(OpenAIProvider::new( diff --git a/server/src/provider/anthropic.rs b/server/src/provider/anthropic.rs index 9239fe7..8ae37ed 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/anthropic.rs @@ -21,12 +21,16 @@ const API_VERSION: &str = "2023-06-01"; #[derive(Debug, Clone)] pub struct AnthropicProvider { client: reqwest::Client, - redis: fred::prelude::Pool, + redis: fred::clients::Client, api_key: String, } impl AnthropicProvider { - pub fn new(http_client: &reqwest::Client, redis: &fred::prelude::Pool, api_key: &str) -> Self { + pub fn new( + http_client: &reqwest::Client, + redis: &fred::clients::Client, + api_key: &str, + ) -> Self { Self { client: http_client.clone(), redis: redis.clone(), diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index d865554..fc63eba 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -19,7 +19,7 @@ const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1"; #[derive(Debug, Clone)] pub struct OpenAIProvider { client: reqwest::Client, - redis: fred::prelude::Pool, + redis: fred::clients::Client, api_key: String, base_url: String, } @@ -27,7 +27,7 @@ pub struct OpenAIProvider { impl OpenAIProvider { pub fn new( http_client: &reqwest::Client, - redis: &fred::prelude::Pool, + redis: &fred::clients::Client, api_key: &str, base_url: Option<&str>, ) -> Self { diff --git a/server/src/provider_models.rs b/server/src/provider_models.rs index 416c259..2b32d89 100644 --- a/server/src/provider_models.rs +++ b/server/src/provider_models.rs @@ -56,12 +56,12 @@ pub enum ModalityType { /// Service to fetch and cache LLM model list from https://models.dev pub struct ModelsDevService { - redis: fred::prelude::Pool, + redis: fred::clients::Client, http_client: reqwest::Client, } impl ModelsDevService { - pub fn new(redis: &fred::prelude::Pool, http_client: &reqwest::Client) -> Self { + pub fn new(redis: &fred::clients::Client, http_client: &reqwest::Client) -> Self { Self { redis: redis.clone(), http_client: http_client.clone(), @@ -111,7 +111,7 @@ impl ModelsDevService { cache.insert(provider_str.to_owned(), parsed_models_str); } - let pipeline = self.redis.next().pipeline(); + let pipeline = self.redis.pipeline(); let _: () = pipeline.hset(CACHE_KEY, cache).await?; let _: () = pipeline.expire(CACHE_KEY, CACHE_TTL, None).await?; let _: () = pipeline.all().await?; diff --git a/server/src/redis.rs b/server/src/redis.rs index c75a61f..09952d1 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -1,10 +1,24 @@ -use std::time::Duration; +use std::{ops::Deref, time::Duration}; -use fred::prelude::{Builder, ClientLike, Config, Pool, ReconnectPolicy, TcpConfig}; -use rocket::fairing::AdHoc; +use deadpool::managed; +use fred::prelude::{Builder, Client, ClientLike, ReconnectPolicy, TcpConfig}; +use rocket::{ + async_trait, + fairing::AdHoc, + http::Status, + outcome::try_outcome, + request::{FromRequest, Outcome}, + Request, State, +}; +use rocket_okapi::OpenApiFromRequest; +use tokio::sync::Mutex; use crate::config::get_app_config; +const REDIS_POOL_SIZE: usize = 4; +const MAX_EXCLUSIVE_CLIENTS: usize = 20; +const EXCLUSIVE_CLIENT_TIMEOUT: Duration = Duration::from_secs(5); + /// Fairing that sets up and initializes the Redis connection pool. pub fn setup_redis() -> AdHoc { AdHoc::on_ignite("Redis", |rocket| async { @@ -13,21 +27,42 @@ pub fn setup_redis() -> AdHoc { "Initialize Redis connection", |rocket| async { let app_config = get_app_config(&rocket); - let config = Config::from_url(&app_config.redis_url) + let config = fred::prelude::Config::from_url(&app_config.redis_url) .expect("RS_CHAT_REDIS_URL should be valid Redis URL"); - let pool = build_redis_pool(config, app_config.redis_pool.unwrap_or(4)) - .expect("Failed to build Redis pool"); + + let pool = + build_redis_pool(config, app_config.redis_pool.unwrap_or(REDIS_POOL_SIZE)) + .expect("Failed to build static Redis pool"); pool.init().await.expect("Failed to connect to Redis"); - rocket.manage(pool) + let exclusive_manager = ExclusiveClientManager::new(pool.clone()); + let exclusive_pool: ExclusiveClientPool = + managed::Pool::builder(exclusive_manager) + .max_size(app_config.max_streams.unwrap_or(MAX_EXCLUSIVE_CLIENTS)) + .runtime(deadpool::Runtime::Tokio1) + .create_timeout(Some(EXCLUSIVE_CLIENT_TIMEOUT)) + .recycle_timeout(Some(EXCLUSIVE_CLIENT_TIMEOUT)) + .wait_timeout(Some(EXCLUSIVE_CLIENT_TIMEOUT)) + .build() + .expect("Failed to build exclusive Redis pool"); + + rocket.manage(pool).manage(exclusive_pool) }, )) .attach(AdHoc::on_shutdown("Shutdown Redis connection", |rocket| { Box::pin(async { - if let Some(pool) = rocket.state::() { - rocket::info!("Shutting down Redis connection"); + if let Some(pool) = rocket.state::() { + rocket::info!("Shutting down static Redis pool"); if let Err(err) = pool.quit().await { - rocket::error!("Failed to shutdown Redis: {}", err); + rocket::warn!("Failed to shutdown Redis: {}", err); + } + } + if let Some(exclusive_pool) = rocket.state::() { + rocket::info!("Shutting down exclusive Redis pool"); + for client in exclusive_pool.manager().clients.lock().await.iter() { + if let Err(err) = client.quit().await { + rocket::warn!("Failed to shutdown Redis client: {}", err); + } } } }) @@ -36,9 +71,9 @@ pub fn setup_redis() -> AdHoc { } pub fn build_redis_pool( - redis_config: Config, + redis_config: fred::prelude::Config, pool_size: usize, -) -> Result { +) -> Result { Builder::from_config(redis_config) .with_connection_config(|config| { config.connection_timeout = Duration::from_secs(4); @@ -55,3 +90,84 @@ pub fn build_redis_pool( }) .build_pool(pool_size) } + +/// A pool of exclusive Redis connections for long-running tasks. +pub type ExclusiveClientPool = managed::Pool; + +/// Deadpool implementation for a pool of exclusive Redis clients. +#[derive(Debug)] +pub struct ExclusiveClientManager { + pool: fred::clients::Pool, + clients: Mutex>, +} +impl ExclusiveClientManager { + pub fn new(pool: fred::clients::Pool) -> Self { + Self { + pool, + clients: Mutex::default(), + } + } +} +impl managed::Manager for ExclusiveClientManager { + type Type = Client; + type Error = fred::error::Error; + + async fn create(&self) -> Result { + let client = self.pool.next().clone_new(); + println!("Creating exclusive Redis client {}", client.id()); + client.init().await?; + self.clients.lock().await.push(client.clone()); + Ok(client) + } + async fn recycle( + &self, + client: &mut Client, + _: &managed::Metrics, + ) -> managed::RecycleResult { + println!("Recycling exclusive Redis client {}", client.id()); + if !client.is_connected() { + client.init().await?; + } + let _: () = client.ping(None).await?; + Ok(()) + } + fn detach(&self, client: &mut Self::Type) { + println!("Detaching exclusive Redis client {}", client.id()); + let client = client.clone(); + self.clients + .blocking_lock() + .retain(|c| c.id() != client.id()); + tokio::spawn(async move { + if let Err(err) = client.quit().await { + rocket::error!("Failed to disconnect Redis client: {}", err); + } + }); + } +} + +/// Request guard to get an exclusive Redis connection for long-running operations. +#[derive(Debug, OpenApiFromRequest)] +pub struct ExclusiveRedisClient(pub managed::Object); + +impl Deref for ExclusiveRedisClient { + type Target = Client; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[async_trait] +impl<'r> FromRequest<'r> for ExclusiveRedisClient { + type Error = (); + + async fn from_request(req: &'r Request<'_>) -> Outcome { + let pool = try_outcome!(req.guard::<&State>().await); + match pool.get().await { + Ok(client) => Outcome::Success(ExclusiveRedisClient(client)), + Err(err) => { + rocket::error!("Failed to initialize Redis client: {}", err); + Outcome::Error((Status::InternalServerError, ())) + } + } + } +} diff --git a/server/src/stream.rs b/server/src/stream.rs index ca278a3..dfe79f6 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -1,6 +1,13 @@ mod llm_writer; mod reader; +use std::collections::HashMap; + +use fred::{ + prelude::{KeysInterface, StreamsInterface}, + types::scan::ScanType, +}; + pub use llm_writer::*; pub use reader::*; @@ -13,14 +20,56 @@ use rocket::{ use rocket_okapi::OpenApiFromRequest; use uuid::Uuid; +use crate::provider::LlmError; + /// Get the key prefix for the user's chat streams in Redis fn get_chat_stream_prefix(user_id: &Uuid) -> String { - format!("user:{}:chat", user_id) + format!("user:{}:chat:", user_id) } /// Get the key of the chat stream in Redis for the given user and session ID fn get_chat_stream_key(user_id: &Uuid, session_id: &Uuid) -> String { - format!("{}:{}", get_chat_stream_prefix(user_id), session_id) + format!("{}{}", get_chat_stream_prefix(user_id), session_id) +} + +/// Get the ongoing chat stream sessions for a user. +pub async fn get_current_chat_streams( + redis: &fred::clients::Client, + user_id: &Uuid, +) -> Result, LlmError> { + let prefix = get_chat_stream_prefix(user_id); + let pattern = format!("{}*", prefix); + let (_, keys): (String, Vec) = redis + .scan_page("0", &pattern, Some(20), Some(ScanType::Stream)) + .await?; + Ok(keys + .into_iter() + .filter_map(|key| Some(key.strip_prefix(&prefix)?.to_string())) + .collect()) +} + +/// Check if the chat stream exists. +pub async fn check_chat_stream_exists( + redis: &fred::clients::Client, + user_id: &Uuid, + session_id: &Uuid, +) -> Result { + let key = get_chat_stream_key(user_id, session_id); + let first_entry: Option<()> = redis.xread(Some(1), None, &key, "0-0").await?; + Ok(first_entry.is_some()) +} + +/// Cancel a stream by adding a `cancel` event to the stream and then deleting it from Redis +/// (not using a pipeline since we need to ensure the `cancel` event is processed before deleting the stream). +pub async fn cancel_current_chat_stream( + redis: &fred::clients::Client, + user_id: &Uuid, + session_id: &Uuid, +) -> Result<(), fred::prelude::Error> { + let key = get_chat_stream_key(user_id, session_id); + let entry: HashMap = RedisStreamChunk::Cancel.into(); + let _: () = redis.xadd(&key, true, None, "*", entry).await?; + redis.del(&key).await } /// Request guard to extract the Last-Event-ID from the request headers diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 69abaa2..621649d 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -11,6 +11,7 @@ use uuid::Uuid; use crate::{ db::models::ChatRsToolCall, provider::{LlmApiStream, LlmError, LlmUsage}, + redis::ExclusiveRedisClient, stream::get_chat_stream_key, }; @@ -28,7 +29,8 @@ const PING_INTERVAL: Duration = Duration::from_secs(2); /// Utility for processing an incoming LLM response stream and writing to a Redis stream. #[derive(Debug)] pub struct LlmStreamWriter { - redis: fred::prelude::Pool, + /// Redis client with an exclusive connection. + redis: ExclusiveRedisClient, /// The key of the Redis stream. key: String, /// The current chunk of data being processed. @@ -54,7 +56,7 @@ struct ChunkState { /// Chunk of the LLM response stored in the Redis stream. #[derive(Debug, Serialize)] #[serde(tag = "type", content = "data", rename_all = "snake_case")] -enum RedisStreamChunk { +pub(super) enum RedisStreamChunk { Start, Ping, Text(String), @@ -72,9 +74,9 @@ impl From for HashMap { } impl LlmStreamWriter { - pub fn new(redis: &fred::prelude::Pool, user_id: &Uuid, session_id: &Uuid) -> Self { + pub fn new(redis: ExclusiveRedisClient, user_id: &Uuid, session_id: &Uuid) -> Self { LlmStreamWriter { - redis: redis.clone(), + redis, key: get_chat_stream_key(user_id, session_id), current_chunk: ChunkState::default(), complete_text: None, @@ -84,34 +86,20 @@ impl LlmStreamWriter { } } - /// Check if the Redis stream already exists. - pub async fn exists(&self) -> Result { - let first_entry: Option<()> = self.redis.xread(Some(1), None, &self.key, "0-0").await?; - Ok(first_entry.is_some()) - } - /// Create the Redis stream and write a `start` entry. pub async fn start(&self) -> Result<(), fred::prelude::Error> { let entry: HashMap = RedisStreamChunk::Start.into(); - let pipeline = self.redis.next().pipeline(); + let pipeline = self.redis.pipeline(); let _: () = pipeline.xadd(&self.key, false, None, "*", entry).await?; let _: () = pipeline.expire(&self.key, STREAM_EXPIRE, None).await?; pipeline.all().await } - /// Cancel the current stream by adding a `cancel` event to the stream and then deleting it from Redis - /// (not using a pipeline since we need to ensure the `cancel` event is processed before deleting the stream). - pub async fn cancel(&self) -> Result<(), fred::prelude::Error> { - let entry: HashMap = RedisStreamChunk::Cancel.into(); - let _: () = self.redis.xadd(&self.key, true, None, "*", entry).await?; - self.redis.del(&self.key).await - } - /// Add an `end` event to notify clients that the stream has ended, and then /// delete the stream from Redis. pub async fn end(&self) -> Result<(), fred::prelude::Error> { let entry: HashMap = RedisStreamChunk::End.into(); - let pipeline = self.redis.next().pipeline(); + let pipeline = self.redis.pipeline(); let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; let _: () = pipeline.del(&self.key).await?; pipeline.all().await @@ -251,7 +239,7 @@ impl LlmStreamWriter { &self, entries: Vec>, ) -> Result<(), LlmError> { - let pipeline = self.redis.next().pipeline(); + let pipeline = self.redis.pipeline(); for entry in entries { let _: () = pipeline .xadd(&self.key, true, ("MAXLEN", "~", 500), "*", entry) @@ -291,34 +279,53 @@ impl LlmStreamWriter { #[cfg(test)] mod tests { use super::*; - use crate::provider::{lorem::LoremProvider, LlmApiProvider, LlmApiProviderSharedOptions}; - use fred::prelude::{Builder, ClientLike, Config, Pool}; + use crate::{ + provider::{lorem::LoremProvider, LlmApiProvider, LlmApiProviderSharedOptions}, + redis::{ExclusiveClientManager, ExclusiveClientPool}, + stream::{cancel_current_chat_stream, check_chat_stream_exists}, + }; + use fred::prelude::{Builder, ClientLike, Config}; use std::time::Duration; - async fn setup_redis_pool() -> Pool { + async fn setup_redis_pool() -> ExclusiveClientPool { let config = Config::from_url("redis://127.0.0.1:6379").unwrap_or_else(|_| Config::default()); let pool = Builder::from_config(config) - .build_pool(2) + .build_pool(1) .expect("Failed to build Redis pool"); pool.init().await.expect("Failed to connect to Redis"); - pool + + let manager = ExclusiveClientManager::new(pool.clone()); + let deadpool: ExclusiveClientPool = deadpool::managed::Pool::builder(manager) + .max_size(3) + .build() + .unwrap(); + + deadpool } - fn create_test_writer(redis: &Pool) -> LlmStreamWriter { - let user_id = Uuid::new_v4(); - let session_id = Uuid::new_v4(); - LlmStreamWriter::new(redis, &user_id, &session_id) + async fn create_test_writer( + redis: &ExclusiveClientPool, + user_id: &Uuid, + session_id: &Uuid, + ) -> LlmStreamWriter { + let client = redis.get().await.expect("Failed to get Redis client"); + LlmStreamWriter::new(ExclusiveRedisClient(client), user_id, session_id) } #[tokio::test] async fn test_stream_writer_basic_functionality() { let redis = setup_redis_pool().await; - let mut writer = create_test_writer(&redis); + let client = redis.get().await.unwrap(); + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; // Create stream assert!(writer.start().await.is_ok()); - assert!(writer.exists().await.unwrap()); + assert!(check_chat_stream_exists(&client, &user_id, &session_id) + .await + .unwrap()); // Create Lorem provider and get stream let lorem = LoremProvider::new(); @@ -346,13 +353,17 @@ mod tests { assert!(writer.end().await.is_ok()); // Stream should be deleted after end - assert!(!writer.exists().await.unwrap()); + assert!(!check_chat_stream_exists(&client, &user_id, &session_id) + .await + .unwrap()); } #[tokio::test] async fn test_stream_writer_batching() { let redis = setup_redis_pool().await; - let mut writer = create_test_writer(&redis); + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; assert!(writer.start().await.is_ok()); @@ -382,7 +393,9 @@ mod tests { #[tokio::test] async fn test_stream_writer_error_handling() { let redis = setup_redis_pool().await; - let mut writer = create_test_writer(&redis); + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; assert!(writer.start().await.is_ok()); @@ -421,7 +434,9 @@ mod tests { #[tokio::test] async fn test_stream_writer_timeout() { let redis = setup_redis_pool().await; - let mut writer = create_test_writer(&redis); + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; assert!(writer.start().await.is_ok()); @@ -453,22 +468,33 @@ mod tests { #[tokio::test] async fn test_stream_writer_cancel() { let redis = setup_redis_pool().await; - let writer = create_test_writer(&redis); + let client = redis.get().await.unwrap(); + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let writer = create_test_writer(&redis, &user_id, &session_id).await; assert!(writer.start().await.is_ok()); - assert!(writer.exists().await.unwrap()); + assert!(check_chat_stream_exists(&client, &user_id, &session_id) + .await + .unwrap()); // Cancel the stream - assert!(writer.cancel().await.is_ok()); + assert!(cancel_current_chat_stream(&client, &user_id, &session_id) + .await + .is_ok()); // Stream should be deleted after cancel - assert!(!writer.exists().await.unwrap()); + assert!(!check_chat_stream_exists(&client, &user_id, &session_id) + .await + .unwrap()); } #[tokio::test] async fn test_stream_writer_usage_tracking() { let redis = setup_redis_pool().await; - let mut writer = create_test_writer(&redis); + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; assert!(writer.start().await.is_ok()); @@ -514,13 +540,18 @@ mod tests { #[tokio::test] async fn test_redis_stream_entries() { let redis = setup_redis_pool().await; - let mut writer = create_test_writer(&redis); + let user_id = Uuid::new_v4(); + let session_id = Uuid::new_v4(); + let mut writer = create_test_writer(&redis, &user_id, &session_id).await; let key = writer.key.clone(); assert!(writer.start().await.is_ok()); // Verify start event was written let entries: Vec<(String, HashMap)> = redis + .get() + .await + .expect("Failed to get Redis connection") .xrange(&key, "-", "+", None) .await .expect("Failed to read stream"); @@ -541,6 +572,9 @@ mod tests { // Should have start + text entries (ping task may add more) let final_entries: Vec<(String, HashMap)> = redis + .get() + .await + .expect("Failed to get Redis connection") .xrange(&key, "-", "+", None) .await .expect("Failed to read stream"); diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index 7c44394..c0eb870 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -1,45 +1,23 @@ use std::collections::HashMap; -use fred::{ - prelude::{KeysInterface, StreamsInterface}, - types::scan::ScanType, -}; +use fred::prelude::StreamsInterface; use rocket::response::stream::Event; use tokio::sync::mpsc; use uuid::Uuid; -use crate::{ - provider::LlmError, - stream::{get_chat_stream_key, get_chat_stream_prefix}, -}; +use crate::{provider::LlmError, redis::ExclusiveRedisClient, stream::get_chat_stream_key}; /// Timeout in milliseconds for the blocking `xread` command. const XREAD_BLOCK_TIMEOUT: u64 = 5_000; // 5 seconds /// Utility for reading SSE events from a Redis stream. pub struct SseStreamReader { - redis: fred::prelude::Pool, + redis: ExclusiveRedisClient, } impl SseStreamReader { - pub fn new(redis: &fred::prelude::Pool) -> Self { - Self { - redis: redis.clone(), - } - } - - /// Get the ongoing chat stream sessions for a user. - pub async fn get_chat_streams(&self, user_id: &Uuid) -> Result, LlmError> { - let prefix = get_chat_stream_prefix(user_id); - let pattern = format!("{}:*", prefix); - let (_, keys): (String, Vec) = self - .redis - .scan_page("0", &pattern, Some(20), Some(ScanType::Stream)) - .await?; - Ok(keys - .into_iter() - .filter_map(|key| Some(key.strip_prefix(&format!("{}:", prefix))?.to_string())) - .collect()) + pub fn new(redis: ExclusiveRedisClient) -> Self { + Self { redis } } /// Retrieve the previous events from the given Redis stream. From 61196f420bdadcd7b9aac640f20d00e1c3889fbe Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 24 Aug 2025 00:44:35 -0400 Subject: [PATCH 25/46] server: add task for cleaning up idle exclusive clients --- server/src/config.rs | 2 +- server/src/redis.rs | 55 ++++++++++++++++++++++++++++---------------- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/server/src/config.rs b/server/src/config.rs index 680a5d6..0fc5183 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -20,7 +20,7 @@ pub struct AppConfig { pub database_url: String, /// Redis connection URL pub redis_url: String, - /// Redis pool size (default: 4) + /// Redis static pool size (default: 4) pub redis_pool: Option, /// Maximum number of concurrent Redis connections for streaming (default: 20) pub max_streams: Option, diff --git a/server/src/redis.rs b/server/src/redis.rs index 09952d1..89d22bc 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -1,4 +1,4 @@ -use std::{ops::Deref, time::Duration}; +use std::{ops::Deref, sync::Arc, time::Duration}; use deadpool::managed; use fred::prelude::{Builder, Client, ClientLike, ReconnectPolicy, TcpConfig}; @@ -15,9 +15,16 @@ use tokio::sync::Mutex; use crate::config::get_app_config; +/// Default size of the static Redis pool. const REDIS_POOL_SIZE: usize = 4; +/// Default maximum number of concurrent exclusive clients (e.g. max concurrent streams) const MAX_EXCLUSIVE_CLIENTS: usize = 20; -const EXCLUSIVE_CLIENT_TIMEOUT: Duration = Duration::from_secs(5); +/// Timeout for connecting and executing commands. +const CLIENT_TIMEOUT: Duration = Duration::from_secs(6); +/// Interval for checking idle exclusive clients. +const IDLE_TASK_INTERVAL: Duration = Duration::from_secs(30); +/// Shut down exclusive clients after this period of inactivity. +const IDLE_TIME: Duration = Duration::from_secs(60); /// Fairing that sets up and initializes the Redis connection pool. pub fn setup_redis() -> AdHoc { @@ -30,23 +37,35 @@ pub fn setup_redis() -> AdHoc { let config = fred::prelude::Config::from_url(&app_config.redis_url) .expect("RS_CHAT_REDIS_URL should be valid Redis URL"); - let pool = + // Build and initialize the static Redis pool + let static_pool = build_redis_pool(config, app_config.redis_pool.unwrap_or(REDIS_POOL_SIZE)) .expect("Failed to build static Redis pool"); - pool.init().await.expect("Failed to connect to Redis"); + static_pool.init().await.expect("Redis connection failed"); - let exclusive_manager = ExclusiveClientManager::new(pool.clone()); + // Build and initialize the dynamic, exclusive Redis pool for long-running tasks + let exclusive_manager = ExclusiveClientManager::new(static_pool.clone()); let exclusive_pool: ExclusiveClientPool = managed::Pool::builder(exclusive_manager) .max_size(app_config.max_streams.unwrap_or(MAX_EXCLUSIVE_CLIENTS)) .runtime(deadpool::Runtime::Tokio1) - .create_timeout(Some(EXCLUSIVE_CLIENT_TIMEOUT)) - .recycle_timeout(Some(EXCLUSIVE_CLIENT_TIMEOUT)) - .wait_timeout(Some(EXCLUSIVE_CLIENT_TIMEOUT)) + .create_timeout(Some(CLIENT_TIMEOUT)) + .recycle_timeout(Some(CLIENT_TIMEOUT)) + .wait_timeout(Some(CLIENT_TIMEOUT)) .build() .expect("Failed to build exclusive Redis pool"); - rocket.manage(pool).manage(exclusive_pool) + // Spawn a task to periodically clean up idle exclusive clients + let idle_task_pool = exclusive_pool.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(IDLE_TASK_INTERVAL); + loop { + interval.tick().await; + idle_task_pool.retain(|_, metrics| metrics.last_used() < IDLE_TIME); + } + }); + + rocket.manage(static_pool).manage(exclusive_pool) }, )) .attach(AdHoc::on_shutdown("Shutdown Redis connection", |rocket| { @@ -76,8 +95,8 @@ pub fn build_redis_pool( ) -> Result { Builder::from_config(redis_config) .with_connection_config(|config| { - config.connection_timeout = Duration::from_secs(4); - config.internal_command_timeout = Duration::from_secs(6); + config.connection_timeout = CLIENT_TIMEOUT; + config.internal_command_timeout = CLIENT_TIMEOUT; config.max_command_attempts = 2; config.tcp = TcpConfig { nodelay: Some(true), @@ -86,7 +105,7 @@ pub fn build_redis_pool( }) .set_policy(ReconnectPolicy::new_linear(0, 10_000, 1000)) .with_performance_config(|config| { - config.default_command_timeout = Duration::from_secs(10); + config.default_command_timeout = CLIENT_TIMEOUT; }) .build_pool(pool_size) } @@ -98,13 +117,13 @@ pub type ExclusiveClientPool = managed::Pool; #[derive(Debug)] pub struct ExclusiveClientManager { pool: fred::clients::Pool, - clients: Mutex>, + clients: Arc>>, } impl ExclusiveClientManager { pub fn new(pool: fred::clients::Pool) -> Self { Self { pool, - clients: Mutex::default(), + clients: Arc::default(), } } } @@ -114,7 +133,6 @@ impl managed::Manager for ExclusiveClientManager { async fn create(&self) -> Result { let client = self.pool.next().clone_new(); - println!("Creating exclusive Redis client {}", client.id()); client.init().await?; self.clients.lock().await.push(client.clone()); Ok(client) @@ -124,7 +142,6 @@ impl managed::Manager for ExclusiveClientManager { client: &mut Client, _: &managed::Metrics, ) -> managed::RecycleResult { - println!("Recycling exclusive Redis client {}", client.id()); if !client.is_connected() { client.init().await?; } @@ -132,12 +149,10 @@ impl managed::Manager for ExclusiveClientManager { Ok(()) } fn detach(&self, client: &mut Self::Type) { - println!("Detaching exclusive Redis client {}", client.id()); let client = client.clone(); - self.clients - .blocking_lock() - .retain(|c| c.id() != client.id()); + let clients = self.clients.clone(); tokio::spawn(async move { + clients.lock().await.retain(|c| c.id() != client.id()); if let Err(err) = client.quit().await { rocket::error!("Failed to disconnect Redis client: {}", err); } From ed5d30915b1ccf073dc12185a8cb5ae10ffac407 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 24 Aug 2025 00:59:00 -0400 Subject: [PATCH 26/46] server: add request guard for static Redis client --- server/src/api/chat.rs | 27 +++++++++++++-------------- server/src/api/provider.rs | 5 +++-- server/src/redis.rs | 32 +++++++++++++++++++++++++------- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index 2bcb95a..b3974d7 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -27,7 +27,7 @@ use crate::{ }, errors::ApiError, provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmError}, - redis::ExclusiveRedisClient, + redis::{ExclusiveRedisClient, RedisClient}, stream::{ cancel_current_chat_stream, check_chat_stream_exists, get_current_chat_streams, LastEventId, LlmStreamWriter, SseStreamReader, @@ -56,9 +56,9 @@ pub struct GetChatStreamsResponse { #[get("/streams")] pub async fn get_chat_streams( user_id: ChatRsUserId, - redis_pool: &State, + redis: RedisClient, ) -> Result, ApiError> { - let sessions = get_current_chat_streams(redis_pool.next(), &user_id).await?; + let sessions = get_current_chat_streams(&redis, &user_id).await?; Ok(Json(GetChatStreamsResponse { sessions })) } @@ -89,15 +89,15 @@ pub async fn send_chat_stream( user_id: ChatRsUserId, db_pool: &State, mut db: DbConnection, - redis_client: ExclusiveRedisClient, - redis_pool: &State, + redis: RedisClient, + redis_writer: ExclusiveRedisClient, encryptor: &State, http_client: &State, session_id: Uuid, mut input: Json>, ) -> Result, ApiError> { // Check that we aren't already streaming a response for this session - if check_chat_stream_exists(&redis_client, &user_id, &session_id).await? { + if check_chat_stream_exists(&redis, &user_id, &session_id).await? { return Err(LlmError::AlreadyStreaming)?; } @@ -118,7 +118,7 @@ pub async fn send_chat_stream( provider.base_url.as_deref(), api_key.as_deref(), &http_client, - redis_pool.next(), + &redis, )?; // Get the user's chosen tools @@ -177,7 +177,7 @@ pub async fn send_chat_stream( let provider_options = input.options.clone(); // Create the Redis stream, then spawn a task to stream and save the response - let mut stream_writer = LlmStreamWriter::new(redis_client, &user_id, &session_id); + let mut stream_writer = LlmStreamWriter::new(redis_writer, &user_id, &session_id); stream_writer.start().await?; tokio::spawn(async move { let (text, tool_calls, usage, errors, cancelled) = stream_writer.process(stream).await; @@ -217,11 +217,11 @@ pub async fn send_chat_stream( #[get("//stream")] pub async fn connect_to_chat_stream( user_id: ChatRsUserId, - redis_client: ExclusiveRedisClient, + redis_reader: ExclusiveRedisClient, session_id: Uuid, start_event_id: Option, ) -> Result + Send>>>, ApiError> { - let stream_reader = SseStreamReader::new(redis_client); + let stream_reader = SseStreamReader::new(redis_reader); // Get all previous events from the Redis stream, and return them if we're already at the end of the stream let (prev_events, last_event_id, is_end) = stream_reader @@ -252,13 +252,12 @@ pub async fn connect_to_chat_stream( #[post("//cancel")] pub async fn cancel_chat_stream( user_id: ChatRsUserId, - redis_pool: &State, + redis: RedisClient, session_id: Uuid, ) -> Result<(), ApiError> { - let client = redis_pool.next(); - if !check_chat_stream_exists(&client, &user_id, &session_id).await? { + if !check_chat_stream_exists(&redis, &user_id, &session_id).await? { return Err(LlmError::StreamNotFound)?; } - cancel_current_chat_stream(&client, &user_id, &session_id).await?; + cancel_current_chat_stream(&redis, &user_id, &session_id).await?; Ok(()) } diff --git a/server/src/api/provider.rs b/server/src/api/provider.rs index 445c849..e66fa02 100644 --- a/server/src/api/provider.rs +++ b/server/src/api/provider.rs @@ -18,6 +18,7 @@ use crate::{ errors::ApiError, provider::build_llm_provider_api, provider_models::LlmModel, + redis::RedisClient, utils::Encryptor, }; @@ -53,7 +54,7 @@ async fn get_all_providers( async fn list_models( user_id: ChatRsUserId, mut db: DbConnection, - redis_pool: &State, + redis: RedisClient, encryptor: &State, http_client: &State, provider_id: i32, @@ -70,7 +71,7 @@ async fn list_models( provider.base_url.as_deref(), api_key.as_deref(), &http_client, - redis_pool.next(), + &redis, )?; Ok(Json(provider_api.list_models().await?)) diff --git a/server/src/redis.rs b/server/src/redis.rs index 89d22bc..a02771f 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -6,7 +6,6 @@ use rocket::{ async_trait, fairing::AdHoc, http::Status, - outcome::try_outcome, request::{FromRequest, Outcome}, Request, State, }; @@ -110,7 +109,26 @@ pub fn build_redis_pool( .build_pool(pool_size) } -/// A pool of exclusive Redis connections for long-running tasks. +/// Request guard for getting a Redis client from the shared static pool. +#[derive(Debug, OpenApiFromRequest)] +pub struct RedisClient(Client); +impl Deref for RedisClient { + type Target = Client; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +#[async_trait] +impl<'r> FromRequest<'r> for RedisClient { + type Error = (); + async fn from_request(req: &'r Request<'_>) -> Outcome { + use fred::clients::Pool; + let pool = req.rocket().state::>().expect("Should exist"); + Outcome::Success(RedisClient(pool.next().clone())) + } +} + +/// A pool of Redis clients with exclusive connections for long-running operations. pub type ExclusiveClientPool = managed::Pool; /// Deadpool implementation for a pool of exclusive Redis clients. @@ -160,23 +178,23 @@ impl managed::Manager for ExclusiveClientManager { } } -/// Request guard to get an exclusive Redis connection for long-running operations. +/// Request guard to get a Redis client with an exclusive connection for long-running operations. #[derive(Debug, OpenApiFromRequest)] pub struct ExclusiveRedisClient(pub managed::Object); - impl Deref for ExclusiveRedisClient { type Target = Client; fn deref(&self) -> &Self::Target { &self.0 } } - #[async_trait] impl<'r> FromRequest<'r> for ExclusiveRedisClient { type Error = (); - async fn from_request(req: &'r Request<'_>) -> Outcome { - let pool = try_outcome!(req.guard::<&State>().await); + let pool = req + .rocket() + .state::>() + .expect("Should exist"); match pool.get().await { Ok(client) => Outcome::Success(ExclusiveRedisClient(client)), Err(err) => { From 91226ba93d41e59005917f2a24a067b2bb0ac79b Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 24 Aug 2025 01:07:43 -0400 Subject: [PATCH 27/46] Update redis.rs --- server/src/redis.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/server/src/redis.rs b/server/src/redis.rs index a02771f..978f51b 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -7,7 +7,7 @@ use rocket::{ fairing::AdHoc, http::Status, request::{FromRequest, Outcome}, - Request, State, + Request, }; use rocket_okapi::OpenApiFromRequest; use tokio::sync::Mutex; @@ -123,12 +123,13 @@ impl<'r> FromRequest<'r> for RedisClient { type Error = (); async fn from_request(req: &'r Request<'_>) -> Outcome { use fred::clients::Pool; - let pool = req.rocket().state::>().expect("Should exist"); + let pool = req.rocket().state::().expect("Exists"); Outcome::Success(RedisClient(pool.next().clone())) } } -/// A pool of Redis clients with exclusive connections for long-running operations. +/// A pool of Redis clients with exclusive connections for long-running operations. Will +/// be stored in Rocket's managed state. pub type ExclusiveClientPool = managed::Pool; /// Deadpool implementation for a pool of exclusive Redis clients. @@ -155,6 +156,7 @@ impl managed::Manager for ExclusiveClientManager { self.clients.lock().await.push(client.clone()); Ok(client) } + async fn recycle( &self, client: &mut Client, @@ -166,6 +168,7 @@ impl managed::Manager for ExclusiveClientManager { let _: () = client.ping(None).await?; Ok(()) } + fn detach(&self, client: &mut Self::Type) { let client = client.clone(); let clients = self.clients.clone(); @@ -191,10 +194,7 @@ impl Deref for ExclusiveRedisClient { impl<'r> FromRequest<'r> for ExclusiveRedisClient { type Error = (); async fn from_request(req: &'r Request<'_>) -> Outcome { - let pool = req - .rocket() - .state::>() - .expect("Should exist"); + let pool = req.rocket().state::().expect("Exists"); match pool.get().await { Ok(client) => Outcome::Success(ExclusiveRedisClient(client)), Err(err) => { From f39800b5210026c5a463adecb969f82d81c8345f Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 24 Aug 2025 01:43:49 -0400 Subject: [PATCH 28/46] server: add info route --- server/src/api.rs | 2 ++ server/src/api/info.rs | 57 ++++++++++++++++++++++++++++++++++++++++++ server/src/lib.rs | 1 + 3 files changed, 60 insertions(+) create mode 100644 server/src/api/info.rs diff --git a/server/src/api.rs b/server/src/api.rs index 45183b1..8033fe7 100644 --- a/server/src/api.rs +++ b/server/src/api.rs @@ -1,6 +1,7 @@ mod api_key; mod auth; mod chat; +mod info; mod provider; mod secret; mod session; @@ -9,6 +10,7 @@ mod tool; pub use api_key::get_routes as api_key_routes; pub use auth::get_routes as auth_routes; pub use chat::get_routes as chat_routes; +pub use info::get_routes as info_routes; pub use provider::get_routes as provider_routes; pub use secret::get_routes as secret_routes; pub use session::get_routes as session_routes; diff --git a/server/src/api/info.rs b/server/src/api/info.rs new file mode 100644 index 0000000..e16ea32 --- /dev/null +++ b/server/src/api/info.rs @@ -0,0 +1,57 @@ +use rocket::{get, serde::json::Json, Route, State}; +use rocket_okapi::{ + okapi::openapi3::OpenApi, openapi, openapi_get_routes_spec, settings::OpenApiSettings, +}; +use schemars::JsonSchema; +use serde::Serialize; + +use crate::{auth::ChatRsUserId, config::AppConfig, errors::ApiError, redis::ExclusiveClientPool}; + +pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { + openapi_get_routes_spec![ + settings: get_info + ] +} + +#[derive(Debug, Serialize, JsonSchema)] +struct InfoResponse { + version: String, + url: String, + redis: RedisStats, +} + +#[derive(Debug, Serialize, JsonSchema)] +struct RedisStats { + /// Number of static connections + r#static: usize, + /// Number of current streaming connections + streaming: usize, + /// Number of available streaming connections + streaming_available: usize, + /// Maximum number of streaming connections + streaming_max: usize, +} + +/// # Get info +/// Get information about the server +#[openapi] +#[get("/")] +async fn get_info( + _user_id: ChatRsUserId, + app_config: &State, + redis_pool: &State, +) -> Result, ApiError> { + let redis_status = redis_pool.status(); + let redis_stats = RedisStats { + r#static: app_config.redis_pool.unwrap_or(4), + streaming: redis_status.size, + streaming_max: redis_status.max_size, + streaming_available: redis_status.available, + }; + + Ok(Json(InfoResponse { + version: format!("v{}", env!("CARGO_PKG_VERSION")), + url: app_config.server_address.clone(), + redis: redis_stats, + })) +} diff --git a/server/src/lib.rs b/server/src/lib.rs index 4286e4e..828ce99 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -41,6 +41,7 @@ pub fn build_rocket() -> rocket::Rocket { mount_endpoints_and_merged_docs! { server, "/api", openapi_settings, "/" => openapi_get_routes_spec![health], + "/info" => api::info_routes(&openapi_settings), "/auth" => api::auth_routes(&openapi_settings), "/provider" => api::provider_routes(&openapi_settings), "/session" => api::session_routes(&openapi_settings), From 0ab61fe51cbd4ba9b5099e33677b331dcf242022 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 24 Aug 2025 01:56:39 -0400 Subject: [PATCH 29/46] web: tweak stale time in React Query --- web/src/components/chat/ChatStreamingToolCalls.tsx | 6 +++--- web/src/lib/api/client.ts | 2 +- web/src/lib/api/provider.ts | 2 ++ web/src/lib/api/tool.ts | 1 + web/src/lib/context/index.ts | 4 ++-- web/src/lib/context/streamManager.ts | 6 +++--- web/src/lib/context/streamManagerState.ts | 6 +++--- 7 files changed, 15 insertions(+), 12 deletions(-) diff --git a/web/src/components/chat/ChatStreamingToolCalls.tsx b/web/src/components/chat/ChatStreamingToolCalls.tsx index 82645ff..fb25930 100644 --- a/web/src/components/chat/ChatStreamingToolCalls.tsx +++ b/web/src/components/chat/ChatStreamingToolCalls.tsx @@ -16,14 +16,14 @@ import { } from "@/components/ui/collapsible"; import useSmoothStreaming from "@/hooks/useSmoothStreaming"; import type { components } from "@/lib/api/types"; -import type { StreamedToolExecution } from "@/lib/context"; +import type { StreamingToolExecution } from "@/lib/context"; import { getToolFromToolCall } from "@/lib/tools"; import { cn, escapeBackticks } from "@/lib/utils"; import { useAutoScroll } from "../ui/chat/hooks/useAutoScroll"; import ChatMessageToolLogs from "./messages/ChatMessageToolLogs"; interface Props { - streamedTools: Record; + streamedTools: Record; toolCalls?: components["schemas"]["ChatRsToolCall"][]; tools?: components["schemas"]["GetAllToolsResponse"]; sessionId: string; @@ -73,7 +73,7 @@ function StreamingToolCall({ onCancel, }: { toolCall: components["schemas"]["ChatRsToolCall"]; - streamedTool: StreamedToolExecution; + streamedTool: StreamingToolExecution; tools?: components["schemas"]["GetAllToolsResponse"]; onCancel: () => void; }) { diff --git a/web/src/lib/api/client.ts b/web/src/lib/api/client.ts index 03677bc..7efc425 100644 --- a/web/src/lib/api/client.ts +++ b/web/src/lib/api/client.ts @@ -12,7 +12,7 @@ export const client = createClient({ export const queryClient = new QueryClient({ defaultOptions: { queries: { - staleTime: 30 * 1000, // 30 seconds to stale data + staleTime: 1000 * 60, // 1 minute to stale data }, }, }); diff --git a/web/src/lib/api/provider.ts b/web/src/lib/api/provider.ts index 00ba138..ddb6aeb 100644 --- a/web/src/lib/api/provider.ts +++ b/web/src/lib/api/provider.ts @@ -8,6 +8,7 @@ const queryKey = ["providerKeys"]; export const useProviders = () => useQuery({ queryKey, + staleTime: 1000 * 60 * 5, // 5 minutes queryFn: async () => { const response = await client.GET("/provider/"); if (response.error) { @@ -21,6 +22,7 @@ export const useProviderModels = (providerId?: number | null) => useQuery({ enabled: !!providerId, queryKey: ["providerModels", { providerId }], + staleTime: Infinity, queryFn: async () => { if (!providerId) return []; const response = await client.GET("/provider/{provider_id}/models", { diff --git a/web/src/lib/api/tool.ts b/web/src/lib/api/tool.ts index e5253f6..75ab095 100644 --- a/web/src/lib/api/tool.ts +++ b/web/src/lib/api/tool.ts @@ -9,6 +9,7 @@ const queryKey = ["tools"]; export const useTools = () => useQuery({ queryKey, + staleTime: 1000 * 60 * 5, // 5 minutes queryFn: async () => { const response = await client.GET("/tool/"); if (response.error) { diff --git a/web/src/lib/context/index.ts b/web/src/lib/context/index.ts index d39d342..fa3871a 100644 --- a/web/src/lib/context/index.ts +++ b/web/src/lib/context/index.ts @@ -1,7 +1,7 @@ import { useStreamingChats } from "./chats"; import type { - StreamedToolExecution, StreamingChat, + StreamingToolExecution, } from "./streamManagerState"; import { useStreamingTools } from "./tools"; @@ -9,5 +9,5 @@ export { useStreamingTools, useStreamingChats, type StreamingChat, - type StreamedToolExecution, + type StreamingToolExecution, }; diff --git a/web/src/lib/context/streamManager.ts b/web/src/lib/context/streamManager.ts index b7c5223..2fe7b67 100644 --- a/web/src/lib/context/streamManager.ts +++ b/web/src/lib/context/streamManager.ts @@ -34,7 +34,7 @@ export function useStreamManager() { } = useStreamManagerData(); const startStream = useCallback( - async (sessionId: string) => { + (sessionId: string) => { clearSession(sessionId); initSession(sessionId); @@ -71,7 +71,7 @@ export function useStreamManager() { ], ); - // Automatically start any ongoing chat streams + // Automatically start any ongoing chat streams (e.g. on browser refresh or switching tabs) useEffect(() => { if (!serverStreams) return; for (const sessionId of serverStreams.sessions) { @@ -96,7 +96,7 @@ export function useStreamManager() { setSessionCompleted(sessionId); return; } - await startStream(sessionId); + startStream(sessionId); }, [initSession, addErrorChunk, startStream, setSessionCompleted], ); diff --git a/web/src/lib/context/streamManagerState.ts b/web/src/lib/context/streamManagerState.ts index 45f865a..a2683a4 100644 --- a/web/src/lib/context/streamManagerState.ts +++ b/web/src/lib/context/streamManagerState.ts @@ -15,14 +15,14 @@ const initialChatState = (): StreamingChat => ({ status: "streaming", }); -export interface StreamedToolExecution { +export interface StreamingToolExecution { result: string; logs: string[]; debugLogs: string[]; error?: string; status: "streaming" | "completed" | "error"; } -const initialToolState = (): StreamedToolExecution => ({ +const initialToolState = (): StreamingToolExecution => ({ result: "", logs: [], debugLogs: [], @@ -35,7 +35,7 @@ export function useStreamManagerState() { }>({}); const [streamedTools, setStreamedTools] = useState<{ - [toolCallId: string]: StreamedToolExecution | undefined; + [toolCallId: string]: StreamingToolExecution | undefined; }>({}); const [activeToolStreams, setActiveToolStreams] = useState<{ From 4d2f1371d53a0e514c5cc1d64ca5bafec8e48c4d Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 24 Aug 2025 04:37:34 -0400 Subject: [PATCH 30/46] server: bump timeout for LLM response --- server/src/stream/llm_writer.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 621649d..3594e96 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -22,7 +22,7 @@ const MAX_CHUNK_SIZE: usize = 200; /// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) const STREAM_EXPIRE: i64 = 30; /// Timeout waiting for data from the LLM stream. -const LLM_TIMEOUT: Duration = Duration::from_secs(20); +const LLM_TIMEOUT: Duration = Duration::from_secs(60); /// Interval for sending ping messages to the Redis stream. const PING_INTERVAL: Duration = Duration::from_secs(2); @@ -453,8 +453,8 @@ mod tests { let elapsed = start.elapsed(); // Should complete in roughly LLM_TIMEOUT duration - assert!(elapsed >= Duration::from_secs(19)); // Allow some margin - assert!(elapsed < Duration::from_secs(25)); + assert!(elapsed >= Duration::from_secs(59)); // Allow some margin + assert!(elapsed < Duration::from_secs(65)); assert!(text.is_none()); assert!(errors.is_some()); From 931152e5da581b3c3e59fe7ec36f1db178ba94bc Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 24 Aug 2025 11:59:48 -0400 Subject: [PATCH 31/46] server: redis stream ping tweaks --- server/src/stream/llm_writer.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 3594e96..537de8a 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -24,7 +24,7 @@ const STREAM_EXPIRE: i64 = 30; /// Timeout waiting for data from the LLM stream. const LLM_TIMEOUT: Duration = Duration::from_secs(60); /// Interval for sending ping messages to the Redis stream. -const PING_INTERVAL: Duration = Duration::from_secs(2); +const PING_INTERVAL: Duration = Duration::from_secs(5); /// Utility for processing an incoming LLM response stream and writing to a Redis stream. #[derive(Debug)] @@ -245,11 +245,10 @@ impl LlmStreamWriter { .xadd(&self.key, true, ("MAXLEN", "~", 500), "*", entry) .await?; } - let _: () = pipeline.expire(&self.key, STREAM_EXPIRE, None).await?; let res: Vec = pipeline.all().await?; // Check for `nil` responses indicating the stream has been deleted/cancelled - if res.iter().any(|r| matches!(r, fred::prelude::Value::Null)) { + if res.iter().any(|r| r.is_null()) { Err(LlmError::StreamNotFound) } else { Ok(()) @@ -265,9 +264,15 @@ impl LlmStreamWriter { loop { interval.tick().await; let entry: HashMap = RedisStreamChunk::Ping.into(); - let res: Result<(), fred::error::Error> = - redis.xadd(&key, true, None, "*", entry).await; - if res.is_err() { + let pipeline = redis.pipeline(); + let _: Result<(), fred::error::Error> = + pipeline.xadd(&key, true, None, "*", entry).await; + let _: Result<(), fred::error::Error> = + pipeline.expire(&key, STREAM_EXPIRE, None).await; + let res: Result, fred::error::Error> = + pipeline.all().await; + + if res.is_err() || res.is_ok_and(|r| r.iter().any(|v| v.is_null())) { break; } } From ec28a13a1c5165672a8372e44558f98d7ee83542 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 24 Aug 2025 17:36:09 -0400 Subject: [PATCH 32/46] server: fix system info schema --- server/src/tools/system/system_info.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/tools/system/system_info.rs b/server/src/tools/system/system_info.rs index 3cb5462..370c387 100644 --- a/server/src/tools/system/system_info.rs +++ b/server/src/tools/system/system_info.rs @@ -25,6 +25,7 @@ const SERVER_URL_DESC: &str = "Get the URL of the server that this chat applicat /// Tool to get system information. #[derive(Debug, JsonSchema)] +#[serde(deny_unknown_fields)] pub struct SystemInfo {} impl SystemInfo { pub fn new() -> Self { From b4322e3df686ea4e40d437a908dd3f784dafa384 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Wed, 27 Aug 2025 17:39:06 -0400 Subject: [PATCH 33/46] server: add pending tool calls to stream --- server/src/provider.rs | 17 +++-- server/src/provider/anthropic.rs | 30 ++------ server/src/provider/lorem.rs | 8 +- server/src/provider/openai.rs | 17 +---- server/src/stream.rs | 6 +- server/src/stream/llm_writer.rs | 123 ++++++++++++++----------------- 6 files changed, 81 insertions(+), 120 deletions(-) diff --git a/server/src/provider.rs b/server/src/provider.rs index e4ee7a4..d1d8e38 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -55,12 +55,19 @@ pub enum LlmError { Redis(#[from] fred::error::Error), } +#[derive(Debug, Clone, serde::Serialize)] +pub struct LlmPendingToolCall { + pub id: String, + pub index: usize, + pub tool_name: String, +} + /// A streaming chunk of data from the LLM provider -#[derive(Default)] -pub struct LlmStreamChunk { - pub text: Option, - pub tool_calls: Option>, - pub usage: Option, +pub enum LlmStreamChunk { + Text(String), + ToolCalls(Vec), + PendingToolCall(LlmPendingToolCall), + Usage(LlmUsage), } /// Usage stats from the LLM provider diff --git a/server/src/provider/anthropic.rs b/server/src/provider/anthropic.rs index 8ae37ed..bf772fa 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/anthropic.rs @@ -146,21 +146,13 @@ impl AnthropicProvider { match event { AnthropicStreamEvent::MessageStart { message } => { if let Some(usage) = message.usage { - yield Ok(LlmStreamChunk { - text: Some(String::new()), - tool_calls: None, - usage: Some(usage.into()), - }); + yield Ok(LlmStreamChunk::Usage(usage.into())); } } AnthropicStreamEvent::ContentBlockStart { content_block, index } => { match content_block { AnthropicResponseContentBlock::Text { text } => { - yield Ok(LlmStreamChunk { - text: Some(text), - tool_calls: None, - usage: None, - }); + yield Ok(LlmStreamChunk::Text(text)); } AnthropicResponseContentBlock::ToolUse { id, name } => { current_tool_calls.push(Some(AnthropicStreamToolCall { @@ -175,11 +167,7 @@ impl AnthropicProvider { AnthropicStreamEvent::ContentBlockDelta { delta, index } => { match delta { AnthropicDelta::TextDelta { text } => { - yield Ok(LlmStreamChunk { - text: Some(text), - tool_calls: None, - usage: None, - }); + yield Ok(LlmStreamChunk::Text(text)); } AnthropicDelta::InputJsonDelta { partial_json } => { if let Some(Some(tool_call)) = current_tool_calls.iter_mut().find(|tc| tc.as_ref().is_some_and(|tc| tc.index == index)) { @@ -197,22 +185,14 @@ impl AnthropicProvider { .and_then(|tc| tc.take()) .and_then(|tc| tc.convert(llm_tools)); if let Some(converted_call) = converted_call { - yield Ok(LlmStreamChunk { - text: None, - tool_calls: Some(vec![converted_call]), - usage: None, - }); + yield Ok(LlmStreamChunk::ToolCalls(vec![converted_call])); } } } } AnthropicStreamEvent::MessageDelta { usage } => { if let Some(usage) = usage { - yield Ok(LlmStreamChunk { - text: Some(String::new()), - tool_calls: None, - usage: Some(usage.into()), - }); + yield Ok(LlmStreamChunk::Usage(usage.into())); } } AnthropicStreamEvent::Error { error } => { diff --git a/server/src/provider/lorem.rs b/server/src/provider/lorem.rs index a3a8f24..85e3744 100644 --- a/server/src/provider/lorem.rs +++ b/server/src/provider/lorem.rs @@ -16,7 +16,7 @@ use crate::{ provider_models::LlmModel, }; -/// A test/dummy provider that streams 'lorem ipsum...' +/// A test/dummy provider that streams 'lorem ipsum...' and emits test errors during the stream #[derive(Debug, Clone)] pub struct LoremProvider { pub config: LoremConfig, @@ -56,11 +56,7 @@ impl Stream for LoremStream { let word = self.words[self.index]; self.index += 1; if self.index == 0 || self.index % 10 != 0 { - std::task::Poll::Ready(Some(Ok(LlmStreamChunk { - text: Some(word.to_owned()), - tool_calls: None, - usage: None, - }))) + std::task::Poll::Ready(Some(Ok(LlmStreamChunk::Text(word.to_owned())))) } else { std::task::Poll::Ready(Some(Err(LlmError::LoremError("Test error")))) } diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index fc63eba..432f238 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -122,11 +122,8 @@ impl OpenAIProvider { Ok(mut response) => { if let Some(choice) = response.choices.pop() { if let Some(delta) = choice.delta { - if delta.content.is_some() { - yield Ok(LlmStreamChunk { - text: delta.content, - ..Default::default() - }); + if let Some(text) = delta.content { + yield Ok(LlmStreamChunk::Text(text)); } if let Some(tool_calls_delta) = delta.tool_calls { @@ -145,10 +142,7 @@ impl OpenAIProvider { // Yield usage information if available if let Some(usage) = response.usage { - yield Ok(LlmStreamChunk { - usage: Some(usage.into()), - ..Default::default() - }); + yield Ok(LlmStreamChunk::Usage(usage.into())); } } Err(e) => { @@ -175,10 +169,7 @@ impl OpenAIProvider { if let Some(rs_chat_tools) = tools { if !tool_calls.is_empty() { - yield Ok(LlmStreamChunk { - tool_calls: Some(tool_calls.into_iter().filter_map(|tc| tc.convert(&rs_chat_tools)).collect()), - ..Default::default() - }); + yield Ok(LlmStreamChunk::ToolCalls(tool_calls.into_iter().filter_map(|tc| tc.convert(&rs_chat_tools)).collect())); } } diff --git a/server/src/stream.rs b/server/src/stream.rs index dfe79f6..3684e43 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -4,7 +4,7 @@ mod reader; use std::collections::HashMap; use fred::{ - prelude::{KeysInterface, StreamsInterface}, + prelude::{FredResult, KeysInterface, StreamsInterface}, types::scan::ScanType, }; @@ -53,7 +53,7 @@ pub async fn check_chat_stream_exists( redis: &fred::clients::Client, user_id: &Uuid, session_id: &Uuid, -) -> Result { +) -> FredResult { let key = get_chat_stream_key(user_id, session_id); let first_entry: Option<()> = redis.xread(Some(1), None, &key, "0-0").await?; Ok(first_entry.is_some()) @@ -65,7 +65,7 @@ pub async fn cancel_current_chat_stream( redis: &fred::clients::Client, user_id: &Uuid, session_id: &Uuid, -) -> Result<(), fred::prelude::Error> { +) -> FredResult<()> { let key = get_chat_stream_key(user_id, session_id); let entry: HashMap = RedisStreamChunk::Cancel.into(); let _: () = redis.xadd(&key, true, None, "*", entry).await?; diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 537de8a..81be000 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -3,14 +3,14 @@ use std::{ time::{Duration, Instant}, }; -use fred::prelude::{KeysInterface, StreamsInterface}; +use fred::prelude::{FredResult, KeysInterface, StreamsInterface}; use rocket::futures::StreamExt; use serde::Serialize; use uuid::Uuid; use crate::{ db::models::ChatRsToolCall, - provider::{LlmApiStream, LlmError, LlmUsage}, + provider::{LlmApiStream, LlmError, LlmPendingToolCall, LlmStreamChunk, LlmUsage}, redis::ExclusiveRedisClient, stream::get_chat_stream_key, }; @@ -50,6 +50,7 @@ pub struct LlmStreamWriter { struct ChunkState { text: Option, tool_calls: Option>, + pending_tool_calls: Option>, error: Option, } @@ -61,6 +62,7 @@ pub(super) enum RedisStreamChunk { Ping, Text(String), ToolCall(String), + PendingToolCall(String), Error(String), Cancel, End, @@ -87,7 +89,7 @@ impl LlmStreamWriter { } /// Create the Redis stream and write a `start` entry. - pub async fn start(&self) -> Result<(), fred::prelude::Error> { + pub async fn start(&self) -> FredResult<()> { let entry: HashMap = RedisStreamChunk::Start.into(); let pipeline = self.redis.pipeline(); let _: () = pipeline.xadd(&self.key, false, None, "*", entry).await?; @@ -97,7 +99,7 @@ impl LlmStreamWriter { /// Add an `end` event to notify clients that the stream has ended, and then /// delete the stream from Redis. - pub async fn end(&self) -> Result<(), fred::prelude::Error> { + pub async fn end(&self) -> FredResult<()> { let entry: HashMap = RedisStreamChunk::End.into(); let pipeline = self.redis.pipeline(); let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; @@ -123,17 +125,14 @@ impl LlmStreamWriter { let mut cancelled = false; loop { match tokio::time::timeout(LLM_TIMEOUT, stream.next()).await { - Ok(Some(Ok(chunk))) => { - if let Some(ref text) = chunk.text { - self.process_text(text); + Ok(Some(Ok(chunk))) => match chunk { + LlmStreamChunk::Text(text) => self.process_text(&text), + LlmStreamChunk::ToolCalls(tool_calls) => self.process_tool_calls(tool_calls), + LlmStreamChunk::PendingToolCall(pending_tool_call) => { + self.process_pending_tool_call(pending_tool_call) } - if let Some(tool_calls) = chunk.tool_calls { - self.process_tool_calls(tool_calls); - } - if let Some(usage_chunk) = chunk.usage { - self.process_usage(usage_chunk); - } - } + LlmStreamChunk::Usage(usage) => self.process_usage(usage), + }, Ok(Some(Err(err))) => self.process_error(err), Ok(None) => break, Err(_) => { @@ -185,6 +184,16 @@ impl LlmStreamWriter { self.tool_calls.get_or_insert_default().extend(tool_calls); } + fn process_pending_tool_call(&mut self, tool_call: LlmPendingToolCall) { + let current_chunk = self + .current_chunk + .pending_tool_calls + .get_or_insert_default(); + if !current_chunk.iter().any(|tc| tc.index == tool_call.index) { + current_chunk.push(tool_call); + } + } + fn process_usage(&mut self, usage_chunk: LlmUsage) { let usage = self.usage.get_or_insert_default(); if let Some(input_tokens) = usage_chunk.input_tokens { @@ -208,7 +217,7 @@ impl LlmStreamWriter { return true; } let text = self.current_chunk.text.as_ref(); - text.is_some_and(|t| t.len() > MAX_CHUNK_SIZE) || last_flush_time.elapsed() > FLUSH_INTERVAL + last_flush_time.elapsed() > FLUSH_INTERVAL || text.is_some_and(|t| t.len() > MAX_CHUNK_SIZE) } async fn flush_chunk(&mut self) -> Result<(), LlmError> { @@ -223,6 +232,11 @@ impl LlmStreamWriter { RedisStreamChunk::ToolCall(serde_json::to_string(&tc).unwrap_or_default()) })); } + if let Some(pending_tool_calls) = chunk_state.pending_tool_calls { + chunks.extend(pending_tool_calls.into_iter().map(|tc| { + RedisStreamChunk::PendingToolCall(serde_json::to_string(&tc).unwrap_or_default()) + })); + } if let Some(error) = chunk_state.error { chunks.push(RedisStreamChunk::Error(error)); } @@ -255,7 +269,7 @@ impl LlmStreamWriter { } } - /// Start task that pings the Redis stream every `PING_INTERVAL` seconds + /// Start task that pings the Redis stream every `PING_INTERVAL` seconds and extends the expiration time fn start_ping_task(&self) -> tokio::task::JoinHandle<()> { let redis = self.redis.clone(); let key = self.key.to_owned(); @@ -265,13 +279,9 @@ impl LlmStreamWriter { interval.tick().await; let entry: HashMap = RedisStreamChunk::Ping.into(); let pipeline = redis.pipeline(); - let _: Result<(), fred::error::Error> = - pipeline.xadd(&key, true, None, "*", entry).await; - let _: Result<(), fred::error::Error> = - pipeline.expire(&key, STREAM_EXPIRE, None).await; - let res: Result, fred::error::Error> = - pipeline.all().await; - + let _: FredResult<()> = pipeline.xadd(&key, true, None, "*", entry).await; + let _: FredResult<()> = pipeline.expire(&key, STREAM_EXPIRE, None).await; + let res: FredResult> = pipeline.all().await; if res.is_err() || res.is_ok_and(|r| r.iter().any(|v| v.is_null())) { break; } @@ -376,13 +386,11 @@ mod tests { let chunks = vec![ "Hello", " ", "world", "!", " ", "This", " ", "is", " ", "a", " ", "test", ]; - let chunk_stream = tokio_stream::iter(chunks.into_iter().map(|text| { - Ok(crate::provider::LlmStreamChunk { - text: Some(text.to_string()), - tool_calls: None, - usage: None, - }) - })); + let chunk_stream = tokio_stream::iter( + chunks + .into_iter() + .map(|text| Ok(LlmStreamChunk::Text(text.into()))), + ); let stream: LlmApiStream = Box::pin(chunk_stream); let (text, _, _, _, cancelled) = writer.process(stream).await; @@ -406,17 +414,9 @@ mod tests { // Create a stream that produces an error let error_stream = tokio_stream::iter(vec![ - Ok(crate::provider::LlmStreamChunk { - text: Some("Hello".to_string()), - tool_calls: None, - usage: None, - }), - Err(crate::provider::LlmError::LoremError("Test error")), - Ok(crate::provider::LlmStreamChunk { - text: Some(" World".to_string()), - tool_calls: None, - usage: None, - }), + Ok(LlmStreamChunk::Text("Hello".to_string())), + Err(LlmError::LoremError("Test error")), + Ok(LlmStreamChunk::Text(" World".to_string())), ]); let stream: LlmApiStream = Box::pin(error_stream); @@ -446,9 +446,7 @@ mod tests { assert!(writer.start().await.is_ok()); // Create a stream that hangs (never yields anything) - let hanging_stream = tokio_stream::pending::< - Result, - >(); + let hanging_stream = tokio_stream::pending::>(); let stream: LlmApiStream = Box::pin(hanging_stream); @@ -505,24 +503,18 @@ mod tests { // Create a stream with usage information let usage_stream = tokio_stream::iter(vec![ - Ok(crate::provider::LlmStreamChunk { - text: Some("Hello".to_string()), - tool_calls: None, - usage: Some(crate::provider::LlmUsage { - input_tokens: Some(10), - output_tokens: Some(5), - cost: Some(0.001), - }), - }), - Ok(crate::provider::LlmStreamChunk { - text: Some(" World".to_string()), - tool_calls: None, - usage: Some(crate::provider::LlmUsage { - input_tokens: None, // Should not override - output_tokens: Some(7), // Should update - cost: Some(0.002), // Should update - }), - }), + Ok(LlmStreamChunk::Text("Hello".into())), + Ok(LlmStreamChunk::Usage(LlmUsage { + input_tokens: Some(10), + output_tokens: Some(5), + cost: Some(0.001), + })), + Ok(LlmStreamChunk::Text(" World".into())), + Ok(LlmStreamChunk::Usage(LlmUsage { + input_tokens: None, // Should not override + output_tokens: Some(7), // Should update + cost: Some(0.002), // Should update + })), ]); let stream: LlmApiStream = Box::pin(usage_stream); @@ -565,12 +557,7 @@ mod tests { assert_eq!(entries[0].1.get("type"), Some(&"start".to_string())); // Create a simple stream - let simple_stream = tokio_stream::iter(vec![Ok(crate::provider::LlmStreamChunk { - text: Some("Test chunk".to_string()), - tool_calls: None, - usage: None, - })]); - + let simple_stream = tokio_stream::iter(vec![Ok(LlmStreamChunk::Text("Test chunk".into()))]); let stream: LlmApiStream = Box::pin(simple_stream); writer.process(stream).await; writer.flush_chunk().await.ok(); From 3ea269ec12536701d0340f0ee13aba2d36d65844 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Wed, 27 Aug 2025 23:23:49 -0400 Subject: [PATCH 34/46] server: add pending tool calls to anthropic and openai providers --- server/src/provider.rs | 1 - server/src/provider/anthropic.rs | 11 +++++++++-- server/src/provider/openai.rs | 14 +++++++++----- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/server/src/provider.rs b/server/src/provider.rs index d1d8e38..c47e0fa 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -57,7 +57,6 @@ pub enum LlmError { #[derive(Debug, Clone, serde::Serialize)] pub struct LlmPendingToolCall { - pub id: String, pub index: usize, pub tool_name: String, } diff --git a/server/src/provider/anthropic.rs b/server/src/provider/anthropic.rs index bf772fa..bb60c77 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/anthropic.rs @@ -8,8 +8,8 @@ use serde::{Deserialize, Serialize}; use crate::{ db::models::{ChatRsMessage, ChatRsMessageRole, ChatRsToolCall}, provider::{ - LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmStreamChunk, - LlmTool, LlmUsage, DEFAULT_MAX_TOKENS, + LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmPendingToolCall, + LlmStreamChunk, LlmTool, LlmUsage, DEFAULT_MAX_TOKENS, }, provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, }; @@ -172,6 +172,10 @@ impl AnthropicProvider { AnthropicDelta::InputJsonDelta { partial_json } => { if let Some(Some(tool_call)) = current_tool_calls.iter_mut().find(|tc| tc.as_ref().is_some_and(|tc| tc.index == index)) { tool_call.input.push_str(&partial_json); + yield Ok(LlmStreamChunk::PendingToolCall(LlmPendingToolCall { + index, + tool_name: tool_call.name.clone() + })); } } } @@ -493,8 +497,11 @@ struct AnthropicError { /// Helper struct for tracking streaming tool calls #[derive(Debug)] struct AnthropicStreamToolCall { + /// Anthropic tool call ID id: String, + /// Index of the tool call in the message index: usize, + /// Name of the tool name: String, /// Partial input parameters (JSON stringified) input: String, diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index 432f238..4862b80 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize}; use crate::{ db::models::{ChatRsMessage, ChatRsMessageRole, ChatRsToolCall}, provider::{ - LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmStreamChunk, - LlmTool, LlmUsage, + LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmPendingToolCall, + LlmStreamChunk, LlmTool, LlmUsage, }, provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, }; @@ -128,6 +128,10 @@ impl OpenAIProvider { if let Some(tool_calls_delta) = delta.tool_calls { for tool_call_delta in tool_calls_delta { + yield Ok(LlmStreamChunk::PendingToolCall(LlmPendingToolCall { + index: tool_call_delta.index, + tool_name: tool_call_delta.function.name.clone(), + })); if let Some(tc) = tool_calls.iter_mut().find(|tc| tc.index == tool_call_delta.index) { if let Some(function_arguments) = tool_call_delta.function.arguments { *tc.function.arguments.get_or_insert_default() += &function_arguments; @@ -403,7 +407,7 @@ struct OpenAIResponseDelta { #[derive(Debug, Deserialize)] struct OpenAIStreamToolCall { id: Option, - index: u32, + index: usize, function: OpenAIStreamToolCallFunction, } @@ -411,7 +415,7 @@ impl OpenAIStreamToolCall { /// Convert OpenAI tool call format to ChatRsToolCall, add tool ID fn convert(self, rs_chat_tools: &[LlmTool]) -> Option { let id = self.id?; - let tool_name = self.function.name?; + let tool_name = self.function.name; let parameters = serde_json::from_str(&self.function.arguments?).ok()?; rs_chat_tools .iter() @@ -429,7 +433,7 @@ impl OpenAIStreamToolCall { /// OpenAI streaming tool call function #[derive(Debug, Deserialize)] struct OpenAIStreamToolCallFunction { - name: Option, + name: String, arguments: Option, } From 4886220952fb411dd045ea20164a06c86eaff36e Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Wed, 27 Aug 2025 23:55:31 -0400 Subject: [PATCH 35/46] server: fix openAI tool calling --- server/src/provider/openai.rs | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index 4862b80..e2193ae 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -128,14 +128,16 @@ impl OpenAIProvider { if let Some(tool_calls_delta) = delta.tool_calls { for tool_call_delta in tool_calls_delta { - yield Ok(LlmStreamChunk::PendingToolCall(LlmPendingToolCall { - index: tool_call_delta.index, - tool_name: tool_call_delta.function.name.clone(), - })); if let Some(tc) = tool_calls.iter_mut().find(|tc| tc.index == tool_call_delta.index) { if let Some(function_arguments) = tool_call_delta.function.arguments { *tc.function.arguments.get_or_insert_default() += &function_arguments; } + if let Some(ref tool_name) = tc.function.name { + yield Ok(LlmStreamChunk::PendingToolCall(LlmPendingToolCall { + index: tool_call_delta.index, + tool_name: tool_name.clone(), + })); + } } else { tool_calls.push(tool_call_delta); } @@ -198,7 +200,12 @@ impl LlmApiProvider for OpenAIProvider { let request = OpenAIRequest { model: &options.model, messages: openai_messages, - max_tokens: options.max_tokens, + max_tokens: (options.max_tokens.is_some() && self.base_url != OPENAI_API_BASE_URL) + .then(|| options.max_tokens.expect("already checked for Some value")), + // OpenAI official API has deprecated `max_tokens` for `max_completion_tokens` + max_completion_tokens: (options.max_tokens.is_some() + && self.base_url == OPENAI_API_BASE_URL) + .then(|| options.max_tokens.expect("already checked for Some value")), temperature: options.temperature, stream: Some(true), stream_options: Some(OpenAIStreamOptions { @@ -305,7 +312,10 @@ impl LlmApiProvider for OpenAIProvider { struct OpenAIRequest<'a> { model: &'a str, messages: Vec>, + #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + max_completion_tokens: Option, temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] stream: Option, @@ -415,7 +425,7 @@ impl OpenAIStreamToolCall { /// Convert OpenAI tool call format to ChatRsToolCall, add tool ID fn convert(self, rs_chat_tools: &[LlmTool]) -> Option { let id = self.id?; - let tool_name = self.function.name; + let tool_name = self.function.name?; let parameters = serde_json::from_str(&self.function.arguments?).ok()?; rs_chat_tools .iter() @@ -433,7 +443,7 @@ impl OpenAIStreamToolCall { /// OpenAI streaming tool call function #[derive(Debug, Deserialize)] struct OpenAIStreamToolCallFunction { - name: String, + name: Option, arguments: Option, } From cf513b736958abc8e000672e4cebf01b2a1c021a Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 28 Aug 2025 05:29:14 -0400 Subject: [PATCH 36/46] server: refactor provider streams to use channels, organize and extract utils --- server/Cargo.lock | 2 - server/Cargo.toml | 2 - server/src/api/chat.rs | 8 +- server/src/db/models/chat.rs | 4 +- server/src/provider.rs | 58 +-- server/src/provider/anthropic.rs | 453 +++------------------- server/src/provider/anthropic/request.rs | 134 +++++++ server/src/provider/anthropic/response.rs | 210 ++++++++++ server/src/provider/lorem.rs | 18 +- server/src/provider/openai.rs | 383 +++--------------- server/src/provider/openai/request.rs | 129 ++++++ server/src/provider/openai/response.rs | 135 +++++++ server/src/provider/utils.rs | 35 ++ server/src/stream.rs | 6 +- server/src/stream/llm_writer.rs | 41 +- server/src/utils/generate_title.rs | 6 +- 16 files changed, 835 insertions(+), 789 deletions(-) create mode 100644 server/src/provider/anthropic/request.rs create mode 100644 server/src/provider/anthropic/response.rs create mode 100644 server/src/provider/openai/request.rs create mode 100644 server/src/provider/openai/response.rs create mode 100644 server/src/provider/utils.rs diff --git a/server/Cargo.lock b/server/Cargo.lock index 56c0a57..860b6fd 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -373,7 +373,6 @@ version = "0.6.0" dependencies = [ "aes-gcm", "astral-tokio-tar", - "async-stream", "bollard", "chrono", "const_format", @@ -399,7 +398,6 @@ dependencies = [ "serde", "serde_json", "subst", - "tempfile", "thiserror", "tokio", "tokio-stream", diff --git a/server/Cargo.toml b/server/Cargo.toml index 121344e..3185807 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -18,7 +18,6 @@ strip = true [dependencies] aes-gcm = "0.10.3" astral-tokio-tar = "0.5.2" -async-stream = "0.3.6" bollard = { version = "0.19.1", features = ["ssl"] } chrono = { version = "0.4.41", features = ["serde"] } const_format = "0.2.34" @@ -59,7 +58,6 @@ schemars = { version = "0.8.22", features = ["chrono", "uuid1"] } serde = { version = "1.0.219" } serde_json = "1.0.140" subst = { version = "0.3.8", features = ["json"] } -tempfile = "3.20.0" thiserror = "2.0.12" tokio = { version = "1.45.1" } tokio-stream = "0.1.17" diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index b3974d7..fbc44c5 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -26,7 +26,7 @@ use crate::{ DbConnection, DbPool, }, errors::ApiError, - provider::{build_llm_provider_api, LlmApiProviderSharedOptions, LlmError}, + provider::{build_llm_provider_api, LlmError, LlmProviderOptions}, redis::{ExclusiveRedisClient, RedisClient}, stream::{ cancel_current_chat_stream, check_chat_stream_exists, get_current_chat_streams, @@ -69,7 +69,7 @@ pub struct SendChatInput<'a> { /// The ID of the provider to chat with provider_id: i32, /// Configuration for the provider - options: LlmApiProviderSharedOptions, + options: LlmProviderOptions, /// Configuration of tools available to the assistant tools: Option, } @@ -176,9 +176,11 @@ pub async fn send_chat_stream( let provider_id = input.provider_id; let provider_options = input.options.clone(); - // Create the Redis stream, then spawn a task to stream and save the response + // Create the Redis stream let mut stream_writer = LlmStreamWriter::new(redis_writer, &user_id, &session_id); stream_writer.start().await?; + + // Spawn a task to stream and save the response tokio::spawn(async move { let (text, tool_calls, usage, errors, cancelled) = stream_writer.process(stream).await; let assistant_meta = AssistantMeta { diff --git a/server/src/db/models/chat.rs b/server/src/db/models/chat.rs index 60080c9..ad3db9e 100644 --- a/server/src/db/models/chat.rs +++ b/server/src/db/models/chat.rs @@ -10,7 +10,7 @@ use uuid::Uuid; use crate::{ db::models::{ChatRsExecutedToolCall, ChatRsToolCall, ChatRsUser}, - provider::{LlmApiProviderSharedOptions, LlmUsage}, + provider::{LlmProviderOptions, LlmUsage}, tools::SendChatToolInput, }; @@ -99,7 +99,7 @@ pub struct AssistantMeta { pub provider_id: i32, /// Options passed to the LLM provider #[serde(skip_serializing_if = "Option::is_none")] - pub provider_options: Option, + pub provider_options: Option, /// The tool calls requested by the assistant #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, diff --git a/server/src/provider.rs b/server/src/provider.rs index c47e0fa..fd7aacb 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -3,6 +3,7 @@ pub mod anthropic; pub mod lorem; pub mod openai; +mod utils; use std::pin::Pin; @@ -25,12 +26,8 @@ pub const DEFAULT_TEMPERATURE: f32 = 0.7; pub enum LlmError { #[error("Missing API key")] MissingApiKey, - #[error("Lorem ipsum error: {0}")] - LoremError(&'static str), - #[error("Anthropic error: {0}")] - AnthropicError(String), - #[error("OpenAI error: {0}")] - OpenAIError(String), + #[error("Provider error: {0}")] + ProviderError(String), #[error("models.dev error: {0}")] ModelsDevError(String), #[error("No chat response")] @@ -45,8 +42,6 @@ pub enum LlmError { NoStreamEvent, #[error("Client disconnected")] ClientDisconnected, - #[error("Timeout waiting for provider response")] - StreamTimeout, #[error("Encryption error")] EncryptionError, #[error("Decryption error")] @@ -55,12 +50,29 @@ pub enum LlmError { Redis(#[from] fred::error::Error), } -#[derive(Debug, Clone, serde::Serialize)] -pub struct LlmPendingToolCall { - pub index: usize, - pub tool_name: String, +/// LLM errors during streaming +#[derive(Debug, thiserror::Error)] +pub enum LlmStreamError { + #[error("Provider error: {0}")] + ProviderError(String), + #[error("Failed to parse event: {0}")] + Parsing(#[from] serde_json::Error), + #[error("Failed to decode response: {0}")] + Decoding(#[from] tokio_util::codec::LinesCodecError), + #[error("Timeout waiting for provider response")] + StreamTimeout, + #[error("Stream was cancelled")] + StreamCancelled, + #[error("Redis error: {0}")] + Redis(#[from] fred::error::Error), } +/// Shared stream response type for LLM providers +pub type LlmStream = Pin + Send>>; + +/// Shared stream chunk result type for LLM providers +pub type LlmStreamChunkResult = Result; + /// A streaming chunk of data from the LLM provider pub enum LlmStreamChunk { Text(String), @@ -69,6 +81,12 @@ pub enum LlmStreamChunk { Usage(LlmUsage), } +#[derive(Debug, Clone, serde::Serialize)] +pub struct LlmPendingToolCall { + pub index: usize, + pub tool_name: String, +} + /// Usage stats from the LLM provider #[derive(Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] pub struct LlmUsage { @@ -79,12 +97,9 @@ pub struct LlmUsage { pub cost: Option, } -/// Shared stream type for LLM providers -pub type LlmApiStream = Pin> + Send>>; - /// Shared configuration for LLM provider requests #[derive(Clone, Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] -pub struct LlmApiProviderSharedOptions { +pub struct LlmProviderOptions { pub model: String, pub temperature: Option, pub max_tokens: Option, @@ -118,15 +133,12 @@ pub trait LlmApiProvider: Send + Sync + DynClone { &self, messages: Vec, tools: Option>, - options: &LlmApiProviderSharedOptions, - ) -> Result; + options: &LlmProviderOptions, + ) -> Result; /// Submit a prompt to the provider (not streamed) - async fn prompt( - &self, - message: &str, - options: &LlmApiProviderSharedOptions, - ) -> Result; + async fn prompt(&self, message: &str, options: &LlmProviderOptions) + -> Result; /// List available models from the provider async fn list_models(&self) -> Result, LlmError>; diff --git a/server/src/provider/anthropic.rs b/server/src/provider/anthropic.rs index bb60c77..162db51 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/anthropic.rs @@ -1,15 +1,27 @@ //! Anthropic LLM provider -use std::collections::HashMap; +mod request; +mod response; -use rocket::async_trait; -use serde::{Deserialize, Serialize}; +use rocket::{async_trait, futures::StreamExt}; +use tokio_stream::wrappers::ReceiverStream; use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole, ChatRsToolCall}, + db::models::ChatRsMessage, provider::{ - LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmPendingToolCall, - LlmStreamChunk, LlmTool, LlmUsage, DEFAULT_MAX_TOKENS, + anthropic::{ + request::{ + build_anthropic_messages, build_anthropic_tools, AnthropicContentBlock, + AnthropicMessage, AnthropicRequest, + }, + response::{ + parse_anthropic_event, AnthropicResponse, AnthropicResponseContentBlock, + AnthropicStreamToolCall, + }, + }, + utils::get_sse_events, + LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmTool, LlmUsage, + DEFAULT_MAX_TOKENS, }, provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, }; @@ -37,203 +49,6 @@ impl AnthropicProvider { api_key: api_key.to_string(), } } - - fn build_messages<'a>( - &self, - messages: &'a [ChatRsMessage], - ) -> (Vec>, Option<&'a str>) { - let system_prompt = messages - .iter() - .rfind(|message| message.role == ChatRsMessageRole::System) - .map(|message| message.content.as_str()); - - let anthropic_messages: Vec = messages - .iter() - .filter_map(|message| { - let role = match message.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Tool => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => return None, - }; - - let mut content_blocks = Vec::new(); - - // Handle tool result messages - if message.role == ChatRsMessageRole::Tool { - if let Some(executed_call) = &message.meta.tool_call { - content_blocks.push(AnthropicContentBlock::ToolResult { - tool_use_id: &executed_call.id, - content: &message.content, - }); - } - } else { - // Handle regular text content - if !message.content.is_empty() { - content_blocks.push(AnthropicContentBlock::Text { - text: &message.content, - }); - } - // Handle tool calls in assistant messages - if let Some(tool_calls) = message - .meta - .assistant - .as_ref() - .and_then(|a| a.tool_calls.as_ref()) - { - for tool_call in tool_calls { - content_blocks.push(AnthropicContentBlock::ToolUse { - id: &tool_call.id, - name: &tool_call.tool_name, - input: &tool_call.parameters, - }); - } - } - } - - if content_blocks.is_empty() { - return None; - } - - Some(AnthropicMessage { - role, - content: content_blocks, - }) - }) - .collect(); - - (anthropic_messages, system_prompt) - } - - fn build_tools<'a>(&self, tools: &'a [LlmTool]) -> Vec> { - tools - .iter() - .map(|tool| AnthropicTool { - name: &tool.name, - description: &tool.description, - input_schema: &tool.input_schema, - }) - .collect() - } - - async fn parse_sse_stream( - &self, - mut response: reqwest::Response, - tools: Option>, - ) -> LlmApiStream { - let stream = async_stream::stream! { - let mut buffer = String::new(); - let mut current_tool_calls: Vec> = Vec::new(); - - while let Some(chunk) = response.chunk().await.transpose() { - match chunk { - Ok(bytes) => { - let text = String::from_utf8_lossy(&bytes); - buffer.push_str(&text); - - while let Some(line_end_idx) = buffer.find('\n') { - let line = buffer[..line_end_idx].trim_end_matches('\r'); - - if line.starts_with("data: ") { - let data = &line[6..]; // Remove "data: " prefix - if data.trim().is_empty() || data == "[DONE]" { - buffer.drain(..=line_end_idx); - continue; - } - - match serde_json::from_str::(data) { - Ok(event) => { - match event { - AnthropicStreamEvent::MessageStart { message } => { - if let Some(usage) = message.usage { - yield Ok(LlmStreamChunk::Usage(usage.into())); - } - } - AnthropicStreamEvent::ContentBlockStart { content_block, index } => { - match content_block { - AnthropicResponseContentBlock::Text { text } => { - yield Ok(LlmStreamChunk::Text(text)); - } - AnthropicResponseContentBlock::ToolUse { id, name } => { - current_tool_calls.push(Some(AnthropicStreamToolCall { - id, - index, - name, - input: String::with_capacity(100), - })); - } - } - } - AnthropicStreamEvent::ContentBlockDelta { delta, index } => { - match delta { - AnthropicDelta::TextDelta { text } => { - yield Ok(LlmStreamChunk::Text(text)); - } - AnthropicDelta::InputJsonDelta { partial_json } => { - if let Some(Some(tool_call)) = current_tool_calls.iter_mut().find(|tc| tc.as_ref().is_some_and(|tc| tc.index == index)) { - tool_call.input.push_str(&partial_json); - yield Ok(LlmStreamChunk::PendingToolCall(LlmPendingToolCall { - index, - tool_name: tool_call.name.clone() - })); - } - } - } - } - AnthropicStreamEvent::ContentBlockStop { index } => { - if let Some(llm_tools) = &tools { - if !current_tool_calls.is_empty() { - let converted_call = current_tool_calls - .iter_mut() - .find(|tc| tc.as_ref().is_some_and(|tc| tc.index == index)) - .and_then(|tc| tc.take()) - .and_then(|tc| tc.convert(llm_tools)); - if let Some(converted_call) = converted_call { - yield Ok(LlmStreamChunk::ToolCalls(vec![converted_call])); - } - } - } - } - AnthropicStreamEvent::MessageDelta { usage } => { - if let Some(usage) = usage { - yield Ok(LlmStreamChunk::Usage(usage.into())); - } - } - AnthropicStreamEvent::Error { error } => { - yield Err(LlmError::AnthropicError( - format!("{}: {}", error.error_type, error.message) - )); - } - _ => {} // Ignore other events (ping, message_stop) - } - } - Err(e) => { - rocket::warn!("Failed to parse SSE event: {} | Data: {}", e, data); - } - } - } else if line.starts_with("event: ") { - let event_type = &line[7..]; - rocket::debug!("SSE event type: {}", event_type); - } else if !line.trim().is_empty() && !line.starts_with(":") { - rocket::debug!("Unexpected SSE line: {}", line); - } - - buffer.drain(..=line_end_idx); - } - } - Err(e) => { - rocket::warn!("Stream chunk error: {}", e); - yield Err(LlmError::AnthropicError(format!("Stream error: {}", e))); - break; - } - } - } - - rocket::debug!("Anthropic SSE stream ended"); - }; - - Box::pin(stream) - } } #[async_trait] @@ -242,11 +57,10 @@ impl LlmApiProvider for AnthropicProvider { &self, messages: Vec, tools: Option>, - options: &LlmApiProviderSharedOptions, - ) -> Result { - let (anthropic_messages, system_prompt) = self.build_messages(&messages); - let anthropic_tools = tools.as_ref().map(|t| self.build_tools(t)); - + options: &LlmProviderOptions, + ) -> Result { + let (anthropic_messages, system_prompt) = build_anthropic_messages(&messages); + let anthropic_tools = tools.as_ref().map(|t| build_anthropic_tools(t)); let request = AnthropicRequest { model: &options.model, messages: anthropic_messages, @@ -266,24 +80,46 @@ impl LlmApiProvider for AnthropicProvider { .json(&request) .send() .await - .map_err(|e| LlmError::AnthropicError(format!("Request failed: {}", e)))?; - + .map_err(|e| LlmError::ProviderError(format!("Anthropic request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); - return Err(LlmError::AnthropicError(format!( - "API error {}: {}", + return Err(LlmError::ProviderError(format!( + "Anthropic API error {}: {}", status, error_text ))); } - Ok(self.parse_sse_stream(response, tools).await) + let (tx, rx) = tokio::sync::mpsc::channel(100); + tokio::spawn(async move { + let mut tool_calls: Vec = Vec::new(); + let mut sse_event_stream = get_sse_events(response); + 'outer: while let Some(event_result) = sse_event_stream.next().await { + match event_result { + Ok(event) => { + for chunk in parse_anthropic_event(event, &tools, &mut tool_calls) { + if tx.send(chunk).await.is_err() { + break 'outer; // receiver dropped, stop streaming + } + } + } + Err(e) => { + if tx.send(Err(e)).await.is_err() { + break 'outer; + } + } + } + } + drop(tx); + }); + + Ok(ReceiverStream::new(rx).boxed()) } async fn prompt( &self, message: &str, - options: &LlmApiProviderSharedOptions, + options: &LlmProviderOptions, ) -> Result { let request = AnthropicRequest { model: &options.model, @@ -307,13 +143,13 @@ impl LlmApiProvider for AnthropicProvider { .json(&request) .send() .await - .map_err(|e| LlmError::AnthropicError(format!("Request failed: {}", e)))?; + .map_err(|e| LlmError::ProviderError(format!("Anthropic request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); - return Err(LlmError::AnthropicError(format!( - "API error {}: {}", + return Err(LlmError::ProviderError(format!( + "Anthropic API error {}: {}", status, error_text ))); } @@ -321,8 +157,7 @@ impl LlmApiProvider for AnthropicProvider { let mut anthropic_response: AnthropicResponse = response .json() .await - .map_err(|e| LlmError::AnthropicError(format!("Failed to parse response: {}", e)))?; - + .map_err(|e| LlmError::ProviderError(format!("Failed to parse response: {}", e)))?; let text = anthropic_response .content .get_mut(0) @@ -331,7 +166,6 @@ impl LlmApiProvider for AnthropicProvider { _ => None, }) .ok_or_else(|| LlmError::NoResponse)?; - if let Some(usage) = anthropic_response.usage { let usage: LlmUsage = usage.into(); println!("Prompt usage: {:?}", usage); @@ -349,182 +183,3 @@ impl LlmApiProvider for AnthropicProvider { Ok(models) } } - -/// Anthropic API request message -#[derive(Debug, Serialize)] -struct AnthropicMessage<'a> { - role: &'a str, - content: Vec>, -} - -/// Anthropic API request body -#[derive(Debug, Serialize)] -struct AnthropicRequest<'a> { - model: &'a str, - messages: Vec>, - max_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - system: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - stream: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>>, -} - -/// Anthropic tool definition -#[derive(Debug, Serialize)] -struct AnthropicTool<'a> { - name: &'a str, - description: &'a str, - input_schema: &'a serde_json::Value, -} - -/// Anthropic content block for messages -#[derive(Debug, Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum AnthropicContentBlock<'a> { - Text { - text: &'a str, - }, - ToolUse { - id: &'a str, - name: &'a str, - input: &'a HashMap, - }, - ToolResult { - tool_use_id: &'a str, - content: &'a str, - }, -} - -/// Anthropic API response content block -#[derive(Debug, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum AnthropicResponseContentBlock { - Text { text: String }, - ToolUse { id: String, name: String }, -} - -/// Anthropic API response usage -#[derive(Debug, Deserialize)] -struct AnthropicUsage { - input_tokens: Option, - output_tokens: Option, -} - -impl From for LlmUsage { - fn from(usage: AnthropicUsage) -> Self { - LlmUsage { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - cost: None, - } - } -} - -/// Anthropic API response -#[derive(Debug, Deserialize)] -struct AnthropicResponse { - content: Vec, - usage: Option, -} - -/// Anthropic stream response (message start) -#[derive(Debug, Deserialize)] -struct AnthropicStreamResponse { - // id: String, - // #[serde(rename = "type")] - // message_type: String, - // role: String, - // content: Vec, - // model: String, - // stop_reason: Option, - // stop_sequence: Option, - usage: Option, -} - -/// Anthropic streaming event types -#[derive(Debug, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum AnthropicStreamEvent { - MessageStart { - message: AnthropicStreamResponse, - }, - ContentBlockStart { - index: usize, - content_block: AnthropicResponseContentBlock, - }, - ContentBlockDelta { - index: usize, - delta: AnthropicDelta, - }, - ContentBlockStop { - index: usize, - }, - MessageDelta { - // delta: AnthropicMessageDelta, - usage: Option, - }, - MessageStop, - Ping, - Error { - error: AnthropicError, - }, -} - -#[derive(Debug, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -enum AnthropicDelta { - TextDelta { text: String }, - InputJsonDelta { partial_json: String }, -} - -#[derive(Debug, Deserialize)] -struct AnthropicMessageDelta { - // stop_reason: Option, - // stop_sequence: Option, -} - -#[derive(Debug, Deserialize)] -struct AnthropicError { - #[serde(rename = "type")] - error_type: String, - message: String, -} - -/// Helper struct for tracking streaming tool calls -#[derive(Debug)] -struct AnthropicStreamToolCall { - /// Anthropic tool call ID - id: String, - /// Index of the tool call in the message - index: usize, - /// Name of the tool - name: String, - /// Partial input parameters (JSON stringified) - input: String, -} - -impl AnthropicStreamToolCall { - /// Convert Anthropic tool call format to ChatRsToolCall - fn convert(self, llm_tools: &[LlmTool]) -> Option { - let input = if self.input.trim().is_empty() { - "{}" - } else { - &self.input - }; - let parameters = serde_json::from_str(input).ok()?; - llm_tools - .iter() - .find(|tool| tool.name == self.name) - .map(|tool| ChatRsToolCall { - id: self.id, - tool_id: tool.tool_id, - tool_name: self.name, - tool_type: tool.tool_type, - parameters, - }) - } -} diff --git a/server/src/provider/anthropic/request.rs b/server/src/provider/anthropic/request.rs new file mode 100644 index 0000000..43af62b --- /dev/null +++ b/server/src/provider/anthropic/request.rs @@ -0,0 +1,134 @@ +use std::collections::HashMap; + +use serde::Serialize; + +use crate::{ + db::models::{ChatRsMessage, ChatRsMessageRole}, + provider::LlmTool, +}; + +pub fn build_anthropic_messages<'a>( + messages: &'a [ChatRsMessage], +) -> (Vec>, Option<&'a str>) { + let system_prompt = messages + .iter() + .rfind(|message| message.role == ChatRsMessageRole::System) + .map(|message| message.content.as_str()); + + let anthropic_messages: Vec = messages + .iter() + .filter_map(|message| { + let role = match message.role { + ChatRsMessageRole::User => "user", + ChatRsMessageRole::Tool => "user", + ChatRsMessageRole::Assistant => "assistant", + ChatRsMessageRole::System => return None, + }; + + let mut content_blocks = Vec::new(); + + // Handle tool result messages + if message.role == ChatRsMessageRole::Tool { + if let Some(executed_call) = &message.meta.tool_call { + content_blocks.push(AnthropicContentBlock::ToolResult { + tool_use_id: &executed_call.id, + content: &message.content, + }); + } + } else { + // Handle regular text content + if !message.content.is_empty() { + content_blocks.push(AnthropicContentBlock::Text { + text: &message.content, + }); + } + // Handle tool calls in assistant messages + if let Some(tool_calls) = message + .meta + .assistant + .as_ref() + .and_then(|a| a.tool_calls.as_ref()) + { + for tool_call in tool_calls { + content_blocks.push(AnthropicContentBlock::ToolUse { + id: &tool_call.id, + name: &tool_call.tool_name, + input: &tool_call.parameters, + }); + } + } + } + + if content_blocks.is_empty() { + return None; + } + + Some(AnthropicMessage { + role, + content: content_blocks, + }) + }) + .collect(); + + (anthropic_messages, system_prompt) +} + +pub fn build_anthropic_tools<'a>(tools: &'a [LlmTool]) -> Vec> { + tools + .iter() + .map(|tool| AnthropicTool { + name: &tool.name, + description: &tool.description, + input_schema: &tool.input_schema, + }) + .collect() +} + +/// Anthropic API request message +#[derive(Debug, Serialize)] +pub struct AnthropicMessage<'a> { + pub role: &'a str, + pub content: Vec>, +} + +/// Anthropic API request body +#[derive(Debug, Serialize)] +pub struct AnthropicRequest<'a> { + pub model: &'a str, + pub messages: Vec>, + pub max_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>>, +} + +/// Anthropic tool definition +#[derive(Debug, Serialize)] +pub struct AnthropicTool<'a> { + name: &'a str, + description: &'a str, + input_schema: &'a serde_json::Value, +} + +/// Anthropic content block for messages +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AnthropicContentBlock<'a> { + Text { + text: &'a str, + }, + ToolUse { + id: &'a str, + name: &'a str, + input: &'a HashMap, + }, + ToolResult { + tool_use_id: &'a str, + content: &'a str, + }, +} diff --git a/server/src/provider/anthropic/response.rs b/server/src/provider/anthropic/response.rs new file mode 100644 index 0000000..42163c4 --- /dev/null +++ b/server/src/provider/anthropic/response.rs @@ -0,0 +1,210 @@ +use serde::Deserialize; + +use crate::{ + db::models::ChatRsToolCall, + provider::{ + LlmPendingToolCall, LlmStreamChunk, LlmStreamChunkResult, LlmStreamError, LlmTool, LlmUsage, + }, +}; + +/// Parse chunks from an Anthropic SSE event. +pub fn parse_anthropic_event( + event: AnthropicStreamEvent, + tools: &Option>, + tool_calls: &mut Vec, +) -> Vec { + let mut chunks: Vec = Vec::new(); + match event { + AnthropicStreamEvent::MessageStart { message } => { + if let Some(usage) = message.usage { + chunks.push(Ok(LlmStreamChunk::Usage(usage.into()))); + } + } + AnthropicStreamEvent::ContentBlockStart { + content_block, + index, + } => match content_block { + AnthropicResponseContentBlock::Text { text } => { + chunks.push(Ok(LlmStreamChunk::Text(text))); + } + AnthropicResponseContentBlock::ToolUse { id, name } => { + tool_calls.push(AnthropicStreamToolCall { + id, + index, + name, + input: String::with_capacity(100), + }); + } + }, + AnthropicStreamEvent::ContentBlockDelta { delta, index } => match delta { + AnthropicDelta::TextDelta { text } => { + chunks.push(Ok(LlmStreamChunk::Text(text))); + } + AnthropicDelta::InputJsonDelta { partial_json } => { + if let Some(tool_call) = tool_calls.iter_mut().find(|tc| tc.index == index) { + tool_call.input.push_str(&partial_json); + let chunk = LlmStreamChunk::PendingToolCall(LlmPendingToolCall { + index, + tool_name: tool_call.name.clone(), + }); + chunks.push(Ok(chunk)); + } + } + }, + AnthropicStreamEvent::ContentBlockStop { index } => { + if let Some(llm_tools) = tools { + if let Some(tc) = tool_calls + .iter() + .position(|tc| tc.index == index) + .map(|i| tool_calls.swap_remove(i)) + { + if let Some(tool_call) = tc.convert(llm_tools) { + let chunk = LlmStreamChunk::ToolCalls(vec![tool_call]); + chunks.push(Ok(chunk)); + } + } + } + } + AnthropicStreamEvent::MessageDelta { usage } => { + if let Some(usage) = usage { + chunks.push(Ok(LlmStreamChunk::Usage(usage.into()))); + } + } + AnthropicStreamEvent::Error { error } => { + let error_msg = format!("{}: {}", error.error_type, error.message); + chunks.push(Err(LlmStreamError::ProviderError(error_msg))); + } + _ => {} // Ignore other events (ping, message_stop) + } + chunks +} + +/// Anthropic API response content block +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AnthropicResponseContentBlock { + Text { text: String }, + ToolUse { id: String, name: String }, +} + +/// Anthropic API response usage +#[derive(Debug, Deserialize)] +pub struct AnthropicUsage { + input_tokens: Option, + output_tokens: Option, +} + +impl From for LlmUsage { + fn from(usage: AnthropicUsage) -> Self { + LlmUsage { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + cost: None, + } + } +} + +/// Anthropic API response +#[derive(Debug, Deserialize)] +pub struct AnthropicResponse { + pub content: Vec, + pub usage: Option, +} + +/// Anthropic stream response (message start) +#[derive(Debug, Deserialize)] +pub struct AnthropicStreamResponse { + // id: String, + // #[serde(rename = "type")] + // message_type: String, + // role: String, + // content: Vec, + // model: String, + // stop_reason: Option, + // stop_sequence: Option, + usage: Option, +} + +/// Anthropic streaming event types +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AnthropicStreamEvent { + MessageStart { + message: AnthropicStreamResponse, + }, + ContentBlockStart { + index: usize, + content_block: AnthropicResponseContentBlock, + }, + ContentBlockDelta { + index: usize, + delta: AnthropicDelta, + }, + ContentBlockStop { + index: usize, + }, + MessageDelta { + // delta: AnthropicMessageDelta, + usage: Option, + }, + MessageStop, + Ping, + Error { + error: AnthropicError, + }, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AnthropicDelta { + TextDelta { text: String }, + InputJsonDelta { partial_json: String }, +} + +#[derive(Debug, Deserialize)] +pub struct AnthropicMessageDelta { + // stop_reason: Option, + // stop_sequence: Option, +} + +#[derive(Debug, Deserialize)] +pub struct AnthropicError { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +/// Helper struct for tracking streaming tool calls +#[derive(Debug)] +pub struct AnthropicStreamToolCall { + /// Anthropic tool call ID + id: String, + /// Index of the tool call in the message + index: usize, + /// Name of the tool + name: String, + /// Partial input parameters (JSON stringified) + input: String, +} + +impl AnthropicStreamToolCall { + /// Convert Anthropic tool call format to ChatRsToolCall + fn convert(self, llm_tools: &[LlmTool]) -> Option { + let input = if self.input.trim().is_empty() { + "{}" + } else { + &self.input + }; + let parameters = serde_json::from_str(input).ok()?; + llm_tools + .iter() + .find(|tool| tool.name == self.name) + .map(|tool| ChatRsToolCall { + id: self.id, + tool_id: tool.tool_id, + tool_name: self.name, + tool_type: tool.tool_type, + parameters, + }) + } +} diff --git a/server/src/provider/lorem.rs b/server/src/provider/lorem.rs index 85e3744..4c13aa3 100644 --- a/server/src/provider/lorem.rs +++ b/server/src/provider/lorem.rs @@ -10,8 +10,8 @@ use tokio::time::{interval, Interval}; use crate::{ db::models::ChatRsMessage, provider::{ - LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmStreamChunk, - LlmTool, + LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, + LlmStreamChunkResult, LlmStreamError, LlmTool, }, provider_models::LlmModel, }; @@ -41,7 +41,7 @@ struct LoremStream { interval: Interval, } impl Stream for LoremStream { - type Item = Result; + type Item = LlmStreamChunkResult; fn poll_next( mut self: Pin<&mut Self>, @@ -58,7 +58,9 @@ impl Stream for LoremStream { if self.index == 0 || self.index % 10 != 0 { std::task::Poll::Ready(Some(Ok(LlmStreamChunk::Text(word.to_owned())))) } else { - std::task::Poll::Ready(Some(Err(LlmError::LoremError("Test error")))) + std::task::Poll::Ready(Some(Err(LlmStreamError::ProviderError( + "Test error".into(), + )))) } } std::task::Poll::Pending => std::task::Poll::Pending, @@ -72,8 +74,8 @@ impl LlmApiProvider for LoremProvider { &self, _messages: Vec, _tools: Option>, - _options: &LlmApiProviderSharedOptions, - ) -> Result { + _options: &LlmProviderOptions, + ) -> Result { let lorem_words = vec![ "Lorem ipsum ", "dolor sit ", @@ -103,7 +105,7 @@ impl LlmApiProvider for LoremProvider { "nulla pariatur.", ]; - let stream: LlmApiStream = Box::pin(LoremStream { + let stream: LlmStream = Box::pin(LoremStream { words: lorem_words, index: 0, interval: interval(Duration::from_millis(self.config.interval.into())), @@ -117,7 +119,7 @@ impl LlmApiProvider for LoremProvider { async fn prompt( &self, _request: &str, - _options: &LlmApiProviderSharedOptions, + _options: &LlmProviderOptions, ) -> Result { Ok("Lorem ipsum".to_string()) } diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index e2193ae..39bd927 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -1,13 +1,23 @@ //! OpenAI (and OpenAI compatible) LLM provider -use rocket::async_trait; -use serde::{Deserialize, Serialize}; +mod request; +mod response; + +use rocket::{async_trait, futures::StreamExt}; +use tokio_stream::wrappers::ReceiverStream; use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole, ChatRsToolCall}, + db::models::ChatRsMessage, provider::{ - LlmApiProvider, LlmApiProviderSharedOptions, LlmApiStream, LlmError, LlmPendingToolCall, - LlmStreamChunk, LlmTool, LlmUsage, + openai::{ + request::{ + build_openai_messages, build_openai_tools, OpenAIMessage, OpenAIRequest, + OpenAIStreamOptions, + }, + response::{parse_openai_event, OpenAIResponse, OpenAIStreamToolCall}, + }, + utils::get_sse_events, + LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, LlmTool, LlmUsage, }, provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, }; @@ -38,152 +48,6 @@ impl OpenAIProvider { base_url: base_url.unwrap_or(OPENAI_API_BASE_URL).to_owned(), } } - - fn build_messages<'a>(&self, messages: &'a [ChatRsMessage]) -> Vec> { - messages - .iter() - .map(|message| { - let role = match message.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => "system", - ChatRsMessageRole::Tool => "tool", - }; - let openai_message = OpenAIMessage { - role, - content: Some(&message.content), - tool_call_id: message.meta.tool_call.as_ref().map(|tc| tc.id.as_str()), - tool_calls: message - .meta - .assistant - .as_ref() - .and_then(|meta| meta.tool_calls.as_ref()) - .map(|tc| { - tc.iter() - .map(|tc| OpenAIToolCall { - id: &tc.id, - tool_type: "function", - function: OpenAIToolCallFunction { - name: &tc.tool_name, - arguments: serde_json::to_string(&tc.parameters) - .unwrap_or_default(), - }, - }) - .collect() - }), - }; - - openai_message - }) - .collect() - } - - fn build_tools<'a>(&self, tools: &'a [LlmTool]) -> Vec> { - tools - .iter() - .map(|tool| OpenAITool { - tool_type: "function", - function: OpenAIToolFunction { - name: &tool.name, - description: &tool.description, - parameters: &tool.input_schema, - strict: true, - }, - }) - .collect() - } - - async fn parse_sse_stream( - &self, - mut response: reqwest::Response, - tools: Option>, - ) -> LlmApiStream { - let stream = async_stream::stream! { - let mut buffer = String::new(); - let mut tool_calls: Vec = Vec::new(); - - while let Some(chunk) = response.chunk().await.transpose() { - match chunk { - Ok(bytes) => { - let text = String::from_utf8_lossy(&bytes); - buffer.push_str(&text); - - while let Some(line_end_idx) = buffer.find('\n') { - let line = buffer[..line_end_idx].trim_end_matches('\r'); - - if line.starts_with("data: ") { - let data = &line[6..]; // Remove "data: " prefix - if data.trim().is_empty() || data == "[DONE]" { - buffer.drain(..=line_end_idx); - continue; - } - - match serde_json::from_str::(data) { - Ok(mut response) => { - if let Some(choice) = response.choices.pop() { - if let Some(delta) = choice.delta { - if let Some(text) = delta.content { - yield Ok(LlmStreamChunk::Text(text)); - } - - if let Some(tool_calls_delta) = delta.tool_calls { - for tool_call_delta in tool_calls_delta { - if let Some(tc) = tool_calls.iter_mut().find(|tc| tc.index == tool_call_delta.index) { - if let Some(function_arguments) = tool_call_delta.function.arguments { - *tc.function.arguments.get_or_insert_default() += &function_arguments; - } - if let Some(ref tool_name) = tc.function.name { - yield Ok(LlmStreamChunk::PendingToolCall(LlmPendingToolCall { - index: tool_call_delta.index, - tool_name: tool_name.clone(), - })); - } - } else { - tool_calls.push(tool_call_delta); - } - } - } - } - } - - // Yield usage information if available - if let Some(usage) = response.usage { - yield Ok(LlmStreamChunk::Usage(usage.into())); - } - } - Err(e) => { - rocket::warn!("Failed to parse SSE event: {} | Data: {}", e, data); - } - } - } else if line.starts_with("event: ") { - let event_type = &line[7..]; - rocket::debug!("SSE event type: {}", event_type); - } else if !line.trim().is_empty() && !line.starts_with(":") { - rocket::debug!("Unexpected SSE line: {}", line); - } - - buffer.drain(..=line_end_idx); - } - } - Err(e) => { - rocket::warn!("Stream chunk error: {}", e); - yield Err(LlmError::OpenAIError(format!("Stream error: {}", e))); - break; - } - } - } - - if let Some(rs_chat_tools) = tools { - if !tool_calls.is_empty() { - yield Ok(LlmStreamChunk::ToolCalls(tool_calls.into_iter().filter_map(|tc| tc.convert(&rs_chat_tools)).collect())); - } - } - - rocket::debug!("SSE stream ended"); - }; - - Box::pin(stream) - } } #[async_trait] @@ -192,10 +56,10 @@ impl LlmApiProvider for OpenAIProvider { &self, messages: Vec, tools: Option>, - options: &LlmApiProviderSharedOptions, - ) -> Result { - let openai_messages = self.build_messages(&messages); - let openai_tools = tools.as_ref().map(|t| self.build_tools(t)); + options: &LlmProviderOptions, + ) -> Result { + let openai_messages = build_openai_messages(&messages); + let openai_tools = tools.as_ref().map(|t| build_openai_tools(t)); let request = OpenAIRequest { model: &options.model, @@ -222,24 +86,56 @@ impl LlmApiProvider for OpenAIProvider { .json(&request) .send() .await - .map_err(|e| LlmError::OpenAIError(format!("Request failed: {}", e)))?; + .map_err(|e| LlmError::ProviderError(format!("OpenAI request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); - return Err(LlmError::OpenAIError(format!( - "API error {}: {}", + return Err(LlmError::ProviderError(format!( + "OpenAI API error {}: {}", status, error_text ))); } - Ok(self.parse_sse_stream(response, tools).await) + let (tx, rx) = tokio::sync::mpsc::channel(100); + tokio::spawn(async move { + let mut tool_calls: Vec = Vec::new(); + let mut sse_event_stream = get_sse_events(response); + 'outer: while let Some(event) = sse_event_stream.next().await { + match event { + Ok(event) => { + for chunk in parse_openai_event(event, &mut tool_calls) { + if tx.send(chunk).await.is_err() { + break 'outer; // receiver dropped, stop streaming + } + } + } + Err(e) => { + if tx.send(Err(e)).await.is_err() { + break 'outer; + } + } + } + } + if !tool_calls.is_empty() { + if let Some(llm_tools) = tools { + let converted = tool_calls + .into_iter() + .filter_map(|tc| tc.convert(&llm_tools)) + .collect(); + tx.send(Ok(LlmStreamChunk::ToolCalls(converted))).await.ok(); + } + } + drop(tx); + }); + + Ok(ReceiverStream::new(rx).boxed()) } async fn prompt( &self, message: &str, - options: &LlmApiProviderSharedOptions, + options: &LlmProviderOptions, ) -> Result { let request = OpenAIRequest { model: &options.model, @@ -261,13 +157,13 @@ impl LlmApiProvider for OpenAIProvider { .json(&request) .send() .await - .map_err(|e| LlmError::OpenAIError(format!("Request failed: {}", e)))?; + .map_err(|e| LlmError::ProviderError(format!("OpenAI request failed: {}", e)))?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); - return Err(LlmError::OpenAIError(format!( - "API error {}: {}", + return Err(LlmError::ProviderError(format!( + "OpenAI API error {}: {}", status, error_text ))); } @@ -275,7 +171,7 @@ impl LlmApiProvider for OpenAIProvider { let mut openai_response: OpenAIResponse = response .json() .await - .map_err(|e| LlmError::OpenAIError(format!("Failed to parse response: {}", e)))?; + .map_err(|e| LlmError::ProviderError(format!("Failed to parse response: {}", e)))?; let text = openai_response .choices @@ -306,162 +202,3 @@ impl LlmApiProvider for OpenAIProvider { Ok(models) } } - -/// OpenAI API request body -#[derive(Debug, Default, Serialize)] -struct OpenAIRequest<'a> { - model: &'a str, - messages: Vec>, - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - max_completion_tokens: Option, - temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - stream: Option, - #[serde(skip_serializing_if = "Option::is_none")] - stream_options: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>>, -} - -/// OpenAI API request message -#[derive(Debug, Default, Serialize)] -struct OpenAIMessage<'a> { - role: &'a str, - #[serde(skip_serializing_if = "Option::is_none")] - content: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_call_id: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>>, -} - -/// OpenAI tool definition -#[derive(Debug, Serialize)] -struct OpenAITool<'a> { - #[serde(rename = "type")] - tool_type: &'a str, - function: OpenAIToolFunction<'a>, -} - -/// OpenAI tool function definition -#[derive(Debug, Serialize)] -struct OpenAIToolFunction<'a> { - name: &'a str, - description: &'a str, - parameters: &'a serde_json::Value, - strict: bool, -} - -/// OpenAI tool call in messages -#[derive(Debug, Serialize)] -struct OpenAIToolCall<'a> { - id: &'a str, - #[serde(rename = "type")] - tool_type: &'a str, - function: OpenAIToolCallFunction<'a>, -} - -/// OpenAI tool call function in messages -#[derive(Debug, Serialize)] -struct OpenAIToolCallFunction<'a> { - name: &'a str, - arguments: String, -} - -/// OpenAI API request stream options -#[derive(Debug, Serialize)] -struct OpenAIStreamOptions { - include_usage: bool, -} - -/// OpenAI API response -#[derive(Debug, Deserialize)] -struct OpenAIResponse { - choices: Vec, - usage: Option, -} - -/// OpenAI API streaming response -#[derive(Debug, Deserialize)] -struct OpenAIStreamResponse { - choices: Vec, - usage: Option, -} - -/// OpenAI API response choice -#[derive(Debug, Deserialize)] -struct OpenAIChoice { - message: Option, - delta: Option, - // finish_reason: Option, -} - -/// OpenAI API response message -#[derive(Debug, Deserialize)] -struct OpenAIResponseMessage { - // role: String, - content: Option, -} - -/// OpenAI API streaming delta -#[derive(Debug, Deserialize)] -struct OpenAIResponseDelta { - // role: Option, - content: Option, - tool_calls: Option>, -} - -/// OpenAI streaming tool call -#[derive(Debug, Deserialize)] -struct OpenAIStreamToolCall { - id: Option, - index: usize, - function: OpenAIStreamToolCallFunction, -} - -impl OpenAIStreamToolCall { - /// Convert OpenAI tool call format to ChatRsToolCall, add tool ID - fn convert(self, rs_chat_tools: &[LlmTool]) -> Option { - let id = self.id?; - let tool_name = self.function.name?; - let parameters = serde_json::from_str(&self.function.arguments?).ok()?; - rs_chat_tools - .iter() - .find(|tool| tool.name == tool_name) - .map(|tool| ChatRsToolCall { - id, - tool_id: tool.tool_id, - tool_name, - tool_type: tool.tool_type, - parameters, - }) - } -} - -/// OpenAI streaming tool call function -#[derive(Debug, Deserialize)] -struct OpenAIStreamToolCallFunction { - name: Option, - arguments: Option, -} - -/// OpenAI API response usage -#[derive(Debug, Deserialize)] -struct OpenAIUsage { - prompt_tokens: Option, - completion_tokens: Option, - cost: Option, - // total_tokens: Option, -} - -impl From for LlmUsage { - fn from(usage: OpenAIUsage) -> Self { - LlmUsage { - input_tokens: usage.prompt_tokens, - output_tokens: usage.completion_tokens, - cost: usage.cost, - } - } -} diff --git a/server/src/provider/openai/request.rs b/server/src/provider/openai/request.rs new file mode 100644 index 0000000..06a5dc2 --- /dev/null +++ b/server/src/provider/openai/request.rs @@ -0,0 +1,129 @@ +use serde::Serialize; + +use crate::{ + db::models::{ChatRsMessage, ChatRsMessageRole}, + provider::LlmTool, +}; + +pub fn build_openai_messages<'a>(messages: &'a [ChatRsMessage]) -> Vec> { + messages + .iter() + .map(|message| { + let role = match message.role { + ChatRsMessageRole::User => "user", + ChatRsMessageRole::Assistant => "assistant", + ChatRsMessageRole::System => "system", + ChatRsMessageRole::Tool => "tool", + }; + let openai_message = OpenAIMessage { + role, + content: Some(&message.content), + tool_call_id: message.meta.tool_call.as_ref().map(|tc| tc.id.as_str()), + tool_calls: message + .meta + .assistant + .as_ref() + .and_then(|meta| meta.tool_calls.as_ref()) + .map(|tc| { + tc.iter() + .map(|tc| OpenAIToolCall { + id: &tc.id, + tool_type: "function", + function: OpenAIToolCallFunction { + name: &tc.tool_name, + arguments: serde_json::to_string(&tc.parameters) + .unwrap_or_default(), + }, + }) + .collect() + }), + }; + + openai_message + }) + .collect() +} + +pub fn build_openai_tools<'a>(tools: &'a [LlmTool]) -> Vec> { + tools + .iter() + .map(|tool| OpenAITool { + tool_type: "function", + function: OpenAIToolFunction { + name: &tool.name, + description: &tool.description, + parameters: &tool.input_schema, + strict: true, + }, + }) + .collect() +} + +/// OpenAI API request body +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest<'a> { + pub model: &'a str, + pub messages: Vec>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>>, +} + +/// OpenAI API request stream options +#[derive(Debug, Serialize)] +pub struct OpenAIStreamOptions { + pub include_usage: bool, +} + +/// OpenAI API request message +#[derive(Debug, Default, Serialize)] +pub struct OpenAIMessage<'a> { + pub role: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>>, +} + +/// OpenAI tool definition +#[derive(Debug, Serialize)] +pub struct OpenAITool<'a> { + #[serde(rename = "type")] + tool_type: &'a str, + function: OpenAIToolFunction<'a>, +} + +/// OpenAI tool function definition +#[derive(Debug, Serialize)] +pub struct OpenAIToolFunction<'a> { + name: &'a str, + description: &'a str, + parameters: &'a serde_json::Value, + strict: bool, +} + +/// OpenAI tool call in messages +#[derive(Debug, Serialize)] +pub struct OpenAIToolCall<'a> { + id: &'a str, + #[serde(rename = "type")] + tool_type: &'a str, + function: OpenAIToolCallFunction<'a>, +} + +/// OpenAI tool call function in messages +#[derive(Debug, Serialize)] +pub struct OpenAIToolCallFunction<'a> { + name: &'a str, + arguments: String, +} diff --git a/server/src/provider/openai/response.rs b/server/src/provider/openai/response.rs new file mode 100644 index 0000000..c06d34e --- /dev/null +++ b/server/src/provider/openai/response.rs @@ -0,0 +1,135 @@ +use serde::Deserialize; + +use crate::{ + db::models::ChatRsToolCall, + provider::{LlmPendingToolCall, LlmStreamChunk, LlmStreamChunkResult, LlmTool, LlmUsage}, +}; + +/// Parse chunks from an OpenAI SSE event +pub fn parse_openai_event( + mut event: OpenAIStreamResponse, + tool_calls: &mut Vec, +) -> Vec { + let mut chunks = Vec::new(); + if let Some(delta) = event.choices.pop().and_then(|c| c.delta) { + if let Some(text) = delta.content { + chunks.push(Ok(LlmStreamChunk::Text(text))); + } + if let Some(tool_calls_delta) = delta.tool_calls { + for tool_call_delta in tool_calls_delta { + if let Some(tc) = tool_calls + .iter_mut() + .find(|tc| tc.index == tool_call_delta.index) + { + if let Some(function_arguments) = tool_call_delta.function.arguments { + *tc.function.arguments.get_or_insert_default() += &function_arguments; + } + if let Some(ref tool_name) = tc.function.name { + let chunk = LlmStreamChunk::PendingToolCall(LlmPendingToolCall { + index: tool_call_delta.index, + tool_name: tool_name.clone(), + }); + chunks.push(Ok(chunk)); + } + } else { + tool_calls.push(tool_call_delta); + } + } + } + } + if let Some(usage) = event.usage { + chunks.push(Ok(LlmStreamChunk::Usage(usage.into()))); + } + + chunks +} + +/// OpenAI API response +#[derive(Debug, Deserialize)] +pub struct OpenAIResponse { + pub choices: Vec, + pub usage: Option, +} + +/// OpenAI API streaming response +#[derive(Debug, Deserialize)] +pub struct OpenAIStreamResponse { + choices: Vec, + usage: Option, +} + +/// OpenAI API response choice +#[derive(Debug, Deserialize)] +pub struct OpenAIChoice { + pub message: Option, + pub delta: Option, + // finish_reason: Option, +} + +/// OpenAI API response message +#[derive(Debug, Deserialize)] +pub struct OpenAIResponseMessage { + // role: String, + pub content: Option, +} + +/// OpenAI API streaming delta +#[derive(Debug, Deserialize)] +pub struct OpenAIResponseDelta { + // role: Option, + content: Option, + tool_calls: Option>, +} + +/// OpenAI streaming tool call +#[derive(Debug, Deserialize)] +pub struct OpenAIStreamToolCall { + id: Option, + index: usize, + function: OpenAIStreamToolCallFunction, +} + +impl OpenAIStreamToolCall { + /// Convert OpenAI tool call format to ChatRsToolCall, add tool ID + pub fn convert(self, rs_chat_tools: &[LlmTool]) -> Option { + let id = self.id?; + let tool_name = self.function.name?; + let parameters = serde_json::from_str(&self.function.arguments?).ok()?; + rs_chat_tools + .iter() + .find(|tool| tool.name == tool_name) + .map(|tool| ChatRsToolCall { + id, + tool_id: tool.tool_id, + tool_name, + tool_type: tool.tool_type, + parameters, + }) + } +} + +/// OpenAI streaming tool call function +#[derive(Debug, Deserialize)] +struct OpenAIStreamToolCallFunction { + name: Option, + arguments: Option, +} + +/// OpenAI API response usage +#[derive(Debug, Deserialize)] +pub struct OpenAIUsage { + prompt_tokens: Option, + completion_tokens: Option, + cost: Option, + // total_tokens: Option, +} + +impl From for LlmUsage { + fn from(usage: OpenAIUsage) -> Self { + LlmUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cost: usage.cost, + } + } +} diff --git a/server/src/provider/utils.rs b/server/src/provider/utils.rs new file mode 100644 index 0000000..8c2b102 --- /dev/null +++ b/server/src/provider/utils.rs @@ -0,0 +1,35 @@ +use rocket::futures::TryStreamExt; +use serde::de::DeserializeOwned; +use tokio_stream::{Stream, StreamExt}; +use tokio_util::{ + codec::{FramedRead, LinesCodec}, + io::StreamReader, +}; + +use crate::provider::LlmStreamError; + +/// Get a stream of deserialized events from a provider SSE stream. +pub fn get_sse_events( + response: reqwest::Response, +) -> impl Stream> { + let stream_reader = StreamReader::new(response.bytes_stream().map_err(std::io::Error::other)); + let line_reader = FramedRead::new(stream_reader, LinesCodec::new()); + + line_reader.filter_map(|line_result| { + match line_result { + Ok(line) => { + if line.len() >= 6 && line.as_bytes().starts_with(b"data: ") { + let data = &line[6..]; // Skip "data: " prefix + if data.trim_start().is_empty() || data == "[DONE]" { + None // Skip empty lines and termination markers + } else { + Some(serde_json::from_str::(data).map_err(LlmStreamError::Parsing)) + } + } else { + None // Ignore non-data lines + } + } + Err(e) => Some(Err(LlmStreamError::Decoding(e))), + } + }) +} diff --git a/server/src/stream.rs b/server/src/stream.rs index 3684e43..65be5e6 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -20,8 +20,6 @@ use rocket::{ use rocket_okapi::OpenApiFromRequest; use uuid::Uuid; -use crate::provider::LlmError; - /// Get the key prefix for the user's chat streams in Redis fn get_chat_stream_prefix(user_id: &Uuid) -> String { format!("user:{}:chat:", user_id) @@ -36,7 +34,7 @@ fn get_chat_stream_key(user_id: &Uuid, session_id: &Uuid) -> String { pub async fn get_current_chat_streams( redis: &fred::clients::Client, user_id: &Uuid, -) -> Result, LlmError> { +) -> FredResult> { let prefix = get_chat_stream_prefix(user_id); let pattern = format!("{}*", prefix); let (_, keys): (String, Vec) = redis @@ -75,14 +73,12 @@ pub async fn cancel_current_chat_stream( /// Request guard to extract the Last-Event-ID from the request headers #[derive(OpenApiFromRequest)] pub struct LastEventId(String); - impl std::ops::Deref for LastEventId { type Target = str; fn deref(&self) -> &Self::Target { &self.0 } } - #[async_trait] impl<'r> FromRequest<'r> for LastEventId { type Error = (); diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 81be000..d1917fd 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -10,7 +10,7 @@ use uuid::Uuid; use crate::{ db::models::ChatRsToolCall, - provider::{LlmApiStream, LlmError, LlmPendingToolCall, LlmStreamChunk, LlmUsage}, + provider::{LlmPendingToolCall, LlmStream, LlmStreamChunk, LlmStreamError, LlmUsage}, redis::ExclusiveRedisClient, stream::get_chat_stream_key, }; @@ -40,7 +40,7 @@ pub struct LlmStreamWriter { /// Accumulated tool calls from the assistant. tool_calls: Option>, /// Accumulated errors during the stream from the LLM provider. - errors: Option>, + errors: Option>, /// Accumulated usage information from the LLM provider. usage: Option, } @@ -111,7 +111,7 @@ impl LlmStreamWriter { /// chunks to a Redis stream, and return the final accumulated response. pub async fn process( &mut self, - mut stream: LlmApiStream, + mut stream: LlmStream, ) -> ( Option, Option>, @@ -136,14 +136,14 @@ impl LlmStreamWriter { Ok(Some(Err(err))) => self.process_error(err), Ok(None) => break, Err(_) => { - self.process_error(LlmError::StreamTimeout); + self.process_error(LlmStreamError::StreamTimeout); break; } } if self.should_flush(&last_flush_time) { if let Err(err) = self.flush_chunk().await { - if matches!(err, LlmError::StreamNotFound) { + if matches!(err, LlmStreamError::StreamCancelled) { self.errors.get_or_insert_default().push(err); cancelled = true; break; @@ -207,7 +207,7 @@ impl LlmStreamWriter { } } - fn process_error(&mut self, err: LlmError) { + fn process_error(&mut self, err: LlmStreamError) { self.current_chunk.error = Some(err.to_string()); self.errors.get_or_insert_default().push(err); } @@ -220,7 +220,9 @@ impl LlmStreamWriter { last_flush_time.elapsed() > FLUSH_INTERVAL || text.is_some_and(|t| t.len() > MAX_CHUNK_SIZE) } - async fn flush_chunk(&mut self) -> Result<(), LlmError> { + /// Flushes the current chunk to the Redis stream. Returns a `LlmStreamError::StreamCancelled` error + /// if the stream has been deleted or cancelled. + async fn flush_chunk(&mut self) -> Result<(), LlmStreamError> { let chunk_state = std::mem::take(&mut self.current_chunk); let mut chunks: Vec = Vec::with_capacity(2); @@ -248,11 +250,12 @@ impl LlmStreamWriter { self.add_to_redis_stream(entries).await } - /// Adds a new entry to the Redis stream. Returns a `LlmError::StreamNotFound` error if the stream has been deleted or cancelled. + /// Adds new entries to the Redis stream. Returns a `LlmStreamError::StreamCancelled` error if the + /// stream has been deleted or cancelled. async fn add_to_redis_stream( &self, entries: Vec>, - ) -> Result<(), LlmError> { + ) -> Result<(), LlmStreamError> { let pipeline = self.redis.pipeline(); for entry in entries { let _: () = pipeline @@ -263,7 +266,7 @@ impl LlmStreamWriter { // Check for `nil` responses indicating the stream has been deleted/cancelled if res.iter().any(|r| r.is_null()) { - Err(LlmError::StreamNotFound) + Err(LlmStreamError::StreamCancelled) } else { Ok(()) } @@ -295,7 +298,7 @@ impl LlmStreamWriter { mod tests { use super::*; use crate::{ - provider::{lorem::LoremProvider, LlmApiProvider, LlmApiProviderSharedOptions}, + provider::{lorem::LoremProvider, LlmApiProvider, LlmProviderOptions}, redis::{ExclusiveClientManager, ExclusiveClientPool}, stream::{cancel_current_chat_stream, check_chat_stream_exists}, }; @@ -345,7 +348,7 @@ mod tests { // Create Lorem provider and get stream let lorem = LoremProvider::new(); let stream = lorem - .chat_stream(vec![], None, &LlmApiProviderSharedOptions::default()) + .chat_stream(vec![], None, &LlmProviderOptions::default()) .await .expect("Failed to create lorem stream"); @@ -392,7 +395,7 @@ mod tests { .map(|text| Ok(LlmStreamChunk::Text(text.into()))), ); - let stream: LlmApiStream = Box::pin(chunk_stream); + let stream: LlmStream = Box::pin(chunk_stream); let (text, _, _, _, cancelled) = writer.process(stream).await; assert!(text.is_some()); @@ -415,11 +418,11 @@ mod tests { // Create a stream that produces an error let error_stream = tokio_stream::iter(vec![ Ok(LlmStreamChunk::Text("Hello".to_string())), - Err(LlmError::LoremError("Test error")), + Err(LlmStreamError::ProviderError("Test error".into())), Ok(LlmStreamChunk::Text(" World".to_string())), ]); - let stream: LlmApiStream = Box::pin(error_stream); + let stream: LlmStream = Box::pin(error_stream); let (text, _, _, errors, cancelled) = writer.process(stream).await; assert!(text.is_some()); @@ -446,9 +449,9 @@ mod tests { assert!(writer.start().await.is_ok()); // Create a stream that hangs (never yields anything) - let hanging_stream = tokio_stream::pending::>(); + let hanging_stream = tokio_stream::pending::>(); - let stream: LlmApiStream = Box::pin(hanging_stream); + let stream: LlmStream = Box::pin(hanging_stream); // This should timeout due to LLM_TIMEOUT let start = std::time::Instant::now(); @@ -517,7 +520,7 @@ mod tests { })), ]); - let stream: LlmApiStream = Box::pin(usage_stream); + let stream: LlmStream = Box::pin(usage_stream); let (text, _, usage, _, cancelled) = writer.process(stream).await; assert!(text.is_some()); @@ -558,7 +561,7 @@ mod tests { // Create a simple stream let simple_stream = tokio_stream::iter(vec![Ok(LlmStreamChunk::Text("Test chunk".into()))]); - let stream: LlmApiStream = Box::pin(simple_stream); + let stream: LlmStream = Box::pin(simple_stream); writer.process(stream).await; writer.flush_chunk().await.ok(); diff --git a/server/src/utils/generate_title.rs b/server/src/utils/generate_title.rs index c8ce15a..a517f1f 100644 --- a/server/src/utils/generate_title.rs +++ b/server/src/utils/generate_title.rs @@ -5,14 +5,14 @@ use uuid::Uuid; use crate::{ db::{models::UpdateChatRsSession, services::ChatDbService, DbConnection, DbPool}, errors::ApiError, - provider::{LlmApiProvider, LlmApiProviderSharedOptions, DEFAULT_TEMPERATURE}, + provider::{LlmApiProvider, LlmProviderOptions, DEFAULT_TEMPERATURE}, }; const TITLE_TOKENS: u32 = 20; const TITLE_PROMPT: &str = "This is the first message sent by a human in a chat session with an AI chatbot. \ Please generate a short title for the session (3-7 words) in plain text \ - (no quotes or prefixes)."; + (no quotes or prefixes)"; /// Spawn a task to generate a title for the chat session pub fn generate_title( @@ -45,7 +45,7 @@ async fn generate( model: String, pool: DbPool, ) -> Result<(), ApiError> { - let provider_options = LlmApiProviderSharedOptions { + let provider_options = LlmProviderOptions { model, temperature: Some(DEFAULT_TEMPERATURE), max_tokens: Some(TITLE_TOKENS), From 9cafe7c39c89142a4af62dc56524632bb0fccdfb Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 28 Aug 2025 05:36:48 -0400 Subject: [PATCH 37/46] server: faster flush interval to Redis stream --- server/src/stream/llm_writer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index d1917fd..4751618 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -16,7 +16,7 @@ use crate::{ }; /// Interval at which chunks are flushed to the Redis stream. -const FLUSH_INTERVAL: Duration = Duration::from_millis(500); +const FLUSH_INTERVAL: Duration = Duration::from_millis(300); /// Max accumulated size of the text chunk before it is automatically flushed to Redis. const MAX_CHUNK_SIZE: usize = 200; /// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) From 7bbae205d613f4c21a16cf92dfec36b5019db3b7 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 28 Aug 2025 23:20:08 -0400 Subject: [PATCH 38/46] server: use async_stream in providers to avoid spawning extra task --- server/src/provider/anthropic.rs | 52 +++++++++-------------- server/src/provider/anthropic/response.rs | 23 +++++----- server/src/provider/openai.rs | 47 +++++++++----------- server/src/provider/openai/request.rs | 2 + server/src/provider/openai/response.rs | 9 +++- 5 files changed, 61 insertions(+), 72 deletions(-) diff --git a/server/src/provider/anthropic.rs b/server/src/provider/anthropic.rs index 162db51..bcea95a 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/anthropic.rs @@ -3,29 +3,25 @@ mod request; mod response; -use rocket::{async_trait, futures::StreamExt}; -use tokio_stream::wrappers::ReceiverStream; +use rocket::{async_stream, async_trait, futures::StreamExt}; use crate::{ db::models::ChatRsMessage, provider::{ - anthropic::{ - request::{ - build_anthropic_messages, build_anthropic_tools, AnthropicContentBlock, - AnthropicMessage, AnthropicRequest, - }, - response::{ - parse_anthropic_event, AnthropicResponse, AnthropicResponseContentBlock, - AnthropicStreamToolCall, - }, - }, - utils::get_sse_events, - LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmTool, LlmUsage, - DEFAULT_MAX_TOKENS, + utils::get_sse_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmTool, + LlmUsage, DEFAULT_MAX_TOKENS, }, provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, }; +use { + request::{ + build_anthropic_messages, build_anthropic_tools, AnthropicContentBlock, AnthropicMessage, + AnthropicRequest, + }, + response::{parse_anthropic_event, AnthropicResponse, AnthropicResponseContentBlock}, +}; + const MESSAGES_API_URL: &str = "https://api.anthropic.com/v1/messages"; const API_VERSION: &str = "2023-06-01"; @@ -90,30 +86,22 @@ impl LlmApiProvider for AnthropicProvider { ))); } - let (tx, rx) = tokio::sync::mpsc::channel(100); - tokio::spawn(async move { - let mut tool_calls: Vec = Vec::new(); + let stream = async_stream::stream! { let mut sse_event_stream = get_sse_events(response); - 'outer: while let Some(event_result) = sse_event_stream.next().await { + let mut tool_calls = Vec::new(); + while let Some(event_result) = sse_event_stream.next().await { match event_result { Ok(event) => { - for chunk in parse_anthropic_event(event, &tools, &mut tool_calls) { - if tx.send(chunk).await.is_err() { - break 'outer; // receiver dropped, stop streaming - } + if let Some(chunk) = parse_anthropic_event(event, tools.as_ref(), &mut tool_calls) { + yield chunk; } - } - Err(e) => { - if tx.send(Err(e)).await.is_err() { - break 'outer; - } - } + }, + Err(e) => yield Err(e), } } - drop(tx); - }); + }; - Ok(ReceiverStream::new(rx).boxed()) + Ok(stream.boxed()) } async fn prompt( diff --git a/server/src/provider/anthropic/response.rs b/server/src/provider/anthropic/response.rs index 42163c4..daf3e04 100644 --- a/server/src/provider/anthropic/response.rs +++ b/server/src/provider/anthropic/response.rs @@ -7,17 +7,16 @@ use crate::{ }, }; -/// Parse chunks from an Anthropic SSE event. +/// Parse an Anthropic SSE event. pub fn parse_anthropic_event( event: AnthropicStreamEvent, - tools: &Option>, + tools: Option<&Vec>, tool_calls: &mut Vec, -) -> Vec { - let mut chunks: Vec = Vec::new(); +) -> Option { match event { AnthropicStreamEvent::MessageStart { message } => { if let Some(usage) = message.usage { - chunks.push(Ok(LlmStreamChunk::Usage(usage.into()))); + return Some(Ok(LlmStreamChunk::Usage(usage.into()))); } } AnthropicStreamEvent::ContentBlockStart { @@ -25,7 +24,7 @@ pub fn parse_anthropic_event( index, } => match content_block { AnthropicResponseContentBlock::Text { text } => { - chunks.push(Ok(LlmStreamChunk::Text(text))); + return Some(Ok(LlmStreamChunk::Text(text))); } AnthropicResponseContentBlock::ToolUse { id, name } => { tool_calls.push(AnthropicStreamToolCall { @@ -38,7 +37,7 @@ pub fn parse_anthropic_event( }, AnthropicStreamEvent::ContentBlockDelta { delta, index } => match delta { AnthropicDelta::TextDelta { text } => { - chunks.push(Ok(LlmStreamChunk::Text(text))); + return Some(Ok(LlmStreamChunk::Text(text))); } AnthropicDelta::InputJsonDelta { partial_json } => { if let Some(tool_call) = tool_calls.iter_mut().find(|tc| tc.index == index) { @@ -47,7 +46,7 @@ pub fn parse_anthropic_event( index, tool_name: tool_call.name.clone(), }); - chunks.push(Ok(chunk)); + return Some(Ok(chunk)); } } }, @@ -60,23 +59,23 @@ pub fn parse_anthropic_event( { if let Some(tool_call) = tc.convert(llm_tools) { let chunk = LlmStreamChunk::ToolCalls(vec![tool_call]); - chunks.push(Ok(chunk)); + return Some(Ok(chunk)); } } } } AnthropicStreamEvent::MessageDelta { usage } => { if let Some(usage) = usage { - chunks.push(Ok(LlmStreamChunk::Usage(usage.into()))); + return Some(Ok(LlmStreamChunk::Usage(usage.into()))); } } AnthropicStreamEvent::Error { error } => { let error_msg = format!("{}: {}", error.error_type, error.message); - chunks.push(Err(LlmStreamError::ProviderError(error_msg))); + return Some(Err(LlmStreamError::ProviderError(error_msg))); } _ => {} // Ignore other events (ping, message_stop) } - chunks + None } /// Anthropic API response content block diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index 39bd927..5153d45 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -3,25 +3,25 @@ mod request; mod response; -use rocket::{async_trait, futures::StreamExt}; -use tokio_stream::wrappers::ReceiverStream; +use rocket::{async_stream, async_trait, futures::StreamExt}; use crate::{ db::models::ChatRsMessage, provider::{ - openai::{ - request::{ - build_openai_messages, build_openai_tools, OpenAIMessage, OpenAIRequest, - OpenAIStreamOptions, - }, - response::{parse_openai_event, OpenAIResponse, OpenAIStreamToolCall}, - }, - utils::get_sse_events, - LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, LlmTool, LlmUsage, + utils::get_sse_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, + LlmStreamChunk, LlmTool, LlmUsage, }, provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, }; +use { + request::{ + build_openai_messages, build_openai_tools, OpenAIMessage, OpenAIRequest, + OpenAIStreamOptions, + }, + response::{parse_openai_event, OpenAIResponse, OpenAIStreamToolCall}, +}; + const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1"; const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1"; @@ -71,6 +71,7 @@ impl LlmApiProvider for OpenAIProvider { && self.base_url == OPENAI_API_BASE_URL) .then(|| options.max_tokens.expect("already checked for Some value")), temperature: options.temperature, + store: (self.base_url == OPENAI_API_BASE_URL).then_some(false), stream: Some(true), stream_options: Some(OpenAIStreamOptions { include_usage: true, @@ -97,24 +98,17 @@ impl LlmApiProvider for OpenAIProvider { ))); } - let (tx, rx) = tokio::sync::mpsc::channel(100); - tokio::spawn(async move { - let mut tool_calls: Vec = Vec::new(); + let stream = async_stream::stream! { let mut sse_event_stream = get_sse_events(response); - 'outer: while let Some(event) = sse_event_stream.next().await { + let mut tool_calls: Vec = Vec::new(); + while let Some(event) = sse_event_stream.next().await { match event { Ok(event) => { for chunk in parse_openai_event(event, &mut tool_calls) { - if tx.send(chunk).await.is_err() { - break 'outer; // receiver dropped, stop streaming - } - } - } - Err(e) => { - if tx.send(Err(e)).await.is_err() { - break 'outer; + yield chunk; } } + Err(e) => yield Err(e), } } if !tool_calls.is_empty() { @@ -123,13 +117,12 @@ impl LlmApiProvider for OpenAIProvider { .into_iter() .filter_map(|tc| tc.convert(&llm_tools)) .collect(); - tx.send(Ok(LlmStreamChunk::ToolCalls(converted))).await.ok(); + yield Ok(LlmStreamChunk::ToolCalls(converted)); } } - drop(tx); - }); + }; - Ok(ReceiverStream::new(rx).boxed()) + Ok(stream.boxed()) } async fn prompt( diff --git a/server/src/provider/openai/request.rs b/server/src/provider/openai/request.rs index 06a5dc2..9feafcf 100644 --- a/server/src/provider/openai/request.rs +++ b/server/src/provider/openai/request.rs @@ -70,6 +70,8 @@ pub struct OpenAIRequest<'a> { pub max_completion_tokens: Option, pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub stream: Option, #[serde(skip_serializing_if = "Option::is_none")] pub stream_options: Option, diff --git a/server/src/provider/openai/response.rs b/server/src/provider/openai/response.rs index c06d34e..534d45d 100644 --- a/server/src/provider/openai/response.rs +++ b/server/src/provider/openai/response.rs @@ -10,7 +10,7 @@ pub fn parse_openai_event( mut event: OpenAIStreamResponse, tool_calls: &mut Vec, ) -> Vec { - let mut chunks = Vec::new(); + let mut chunks = Vec::with_capacity(1); if let Some(delta) = event.choices.pop().and_then(|c| c.delta) { if let Some(text) = delta.content { chunks.push(Ok(LlmStreamChunk::Text(text))); @@ -32,6 +32,13 @@ pub fn parse_openai_event( chunks.push(Ok(chunk)); } } else { + if let Some(ref tool_name) = tool_call_delta.function.name { + let chunk = LlmStreamChunk::PendingToolCall(LlmPendingToolCall { + index: tool_call_delta.index, + tool_name: tool_name.clone(), + }); + chunks.push(Ok(chunk)); + } tool_calls.push(tool_call_delta); } } From c8fb8504929a2ba8e6b575a0079f5d310ad26205 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 28 Aug 2025 23:24:55 -0400 Subject: [PATCH 39/46] server: doc updates --- ARCHITECTURE.md | 1 - server/src/stream/llm_writer.rs | 8 +++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 500f6e2..c3e3e29 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -36,7 +36,6 @@ The core component that processes LLM provider streams and manages Redis stream **Key Features:** - **Batching**: Accumulates chunks from the provider stream, up to a max length or timeout - **Background Pings**: Sends regular keepalive pings -- **Timeout Detection**: 20-second timeout for idle LLM streams - **Database Integration**: Saves final responses to PostgreSQL #### 2. Redis and SSE Stream Structure diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 4751618..5ca3129 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -134,9 +134,15 @@ impl LlmStreamWriter { LlmStreamChunk::Usage(usage) => self.process_usage(usage), }, Ok(Some(Err(err))) => self.process_error(err), - Ok(None) => break, + Ok(None) => { + // stream ended + self.flush_chunk().await.ok(); + break; + } Err(_) => { + // timed out waiting for provider response self.process_error(LlmStreamError::StreamTimeout); + self.flush_chunk().await.ok(); break; } } From 84c90ed40272ce7b0fffddbb7138b1e6287c38f1 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 29 Aug 2025 00:41:27 -0400 Subject: [PATCH 40/46] add ollama provider --- server/src/api/tool.rs | 1 + server/src/db/models/provider.rs | 3 + server/src/db/models/tool.rs | 3 + server/src/provider.rs | 10 +- server/src/provider/ollama.rs | 162 ++++++++++++++++++++++++ server/src/provider/ollama/request.rs | 166 +++++++++++++++++++++++++ server/src/provider/ollama/response.rs | 155 +++++++++++++++++++++++ 7 files changed, 499 insertions(+), 1 deletion(-) create mode 100644 server/src/provider/ollama.rs create mode 100644 server/src/provider/ollama/request.rs create mode 100644 server/src/provider/ollama/response.rs diff --git a/server/src/api/tool.rs b/server/src/api/tool.rs index b674d63..b9f9240 100644 --- a/server/src/api/tool.rs +++ b/server/src/api/tool.rs @@ -256,6 +256,7 @@ async fn execute_tool( tool_call: Some(ChatRsExecutedToolCall { id: tool_call.id, tool_id: tool_call.tool_id, + tool_name: tool_call.tool_name, tool_type: tool_call.tool_type, response_format: format, is_error, diff --git a/server/src/db/models/provider.rs b/server/src/db/models/provider.rs index 6ecdbb3..6ce9af7 100644 --- a/server/src/db/models/provider.rs +++ b/server/src/db/models/provider.rs @@ -51,6 +51,7 @@ pub struct UpdateChatRsProvider<'a> { pub enum ChatRsProviderType { Anthropic, Openai, + Ollama, Lorem, } @@ -61,6 +62,7 @@ impl TryFrom<&str> for ChatRsProviderType { match value { "anthropic" => Ok(ChatRsProviderType::Anthropic), "openai" => Ok(ChatRsProviderType::Openai), + "ollama" => Ok(ChatRsProviderType::Ollama), "lorem" => Ok(ChatRsProviderType::Lorem), _ => Err(LlmError::UnsupportedProvider), } @@ -72,6 +74,7 @@ impl From<&ChatRsProviderType> for &str { match value { ChatRsProviderType::Anthropic => "anthropic", ChatRsProviderType::Openai => "openai", + ChatRsProviderType::Ollama => "ollama", ChatRsProviderType::Lorem => "lorem", } } diff --git a/server/src/db/models/tool.rs b/server/src/db/models/tool.rs index 39b2d29..a328ea5 100644 --- a/server/src/db/models/tool.rs +++ b/server/src/db/models/tool.rs @@ -77,6 +77,9 @@ pub struct ChatRsExecutedToolCall { pub id: String, /// ID of the tool used pub tool_id: Uuid, + /// Name of the tool used + #[serde(default)] + pub tool_name: String, /// Type of the tool used #[serde(default)] pub tool_type: LlmToolType, diff --git a/server/src/provider.rs b/server/src/provider.rs index fd7aacb..bc4fb32 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -2,6 +2,7 @@ pub mod anthropic; pub mod lorem; +pub mod ollama; pub mod openai; mod utils; @@ -14,7 +15,10 @@ use uuid::Uuid; use crate::{ db::models::{ChatRsMessage, ChatRsProviderType, ChatRsToolCall}, - provider::{anthropic::AnthropicProvider, lorem::LoremProvider, openai::OpenAIProvider}, + provider::{ + anthropic::AnthropicProvider, lorem::LoremProvider, ollama::OllamaProvider, + openai::OpenAIProvider, + }, provider_models::LlmModel, }; @@ -164,6 +168,10 @@ pub fn build_llm_provider_api( redis, api_key.ok_or(LlmError::MissingApiKey)?, ))), + ChatRsProviderType::Ollama => Ok(Box::new(OllamaProvider::new( + http_client, + base_url.unwrap_or("http://localhost:11434"), + ))), ChatRsProviderType::Lorem => Ok(Box::new(LoremProvider::new())), } } diff --git a/server/src/provider/ollama.rs b/server/src/provider/ollama.rs new file mode 100644 index 0000000..00cb93a --- /dev/null +++ b/server/src/provider/ollama.rs @@ -0,0 +1,162 @@ +//! Ollama LLM provider + +mod request; +mod response; + +use rocket::{async_stream, async_trait, futures::StreamExt}; + +use crate::{ + db::models::ChatRsMessage, + provider::{ + ollama::{ + request::{ + build_ollama_messages, build_ollama_tools, OllamaChatRequest, + OllamaCompletionRequest, OllamaOptions, + }, + response::{parse_ollama_event, OllamaCompletionResponse, OllamaToolCall}, + }, + utils::get_sse_events, + LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, LlmTool, LlmUsage, + }, + provider_models::LlmModel, +}; + +const CHAT_API_URL: &str = "/api/chat"; + +/// Ollama chat provider +#[derive(Debug, Clone)] +pub struct OllamaProvider { + client: reqwest::Client, + base_url: String, +} + +impl OllamaProvider { + pub fn new(http_client: &reqwest::Client, base_url: &str) -> Self { + Self { + client: http_client.clone(), + base_url: base_url.trim_end_matches('/').to_string(), + } + } +} + +#[async_trait] +impl LlmApiProvider for OllamaProvider { + async fn chat_stream( + &self, + messages: Vec, + tools: Option>, + options: &LlmProviderOptions, + ) -> Result { + let ollama_messages = build_ollama_messages(&messages); + let ollama_tools = tools.as_ref().map(|t| build_ollama_tools(t)); + let ollama_options = OllamaOptions { + temperature: options.temperature, + num_predict: options.max_tokens, + ..Default::default() + }; + let request = OllamaChatRequest { + model: &options.model, + messages: ollama_messages, + tools: ollama_tools, + stream: Some(true), + options: Some(ollama_options), + }; + + let response = self + .client + .post(format!("{}{}", self.base_url, CHAT_API_URL)) + .header("content-type", "application/json") + .json(&request) + .send() + .await + .map_err(|e| LlmError::ProviderError(format!("Ollama request failed: {}", e)))?; + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(LlmError::ProviderError(format!( + "Ollama API error {}: {}", + status, error_text + ))); + } + + let stream = async_stream::stream! { + let mut sse_event_stream = get_sse_events(response); + let mut tool_calls: Vec = Vec::new(); + while let Some(event) = sse_event_stream.next().await { + match event { + Ok(event) => { + for chunk in parse_ollama_event(event, &mut tool_calls) { + yield chunk; + } + } + Err(e) => yield Err(e), + } + } + if !tool_calls.is_empty() { + if let Some(llm_tools) = tools { + let converted = tool_calls + .into_iter() + .filter_map(|tc| tc.function.convert(&llm_tools)) + .collect(); + yield Ok(LlmStreamChunk::ToolCalls(converted)); + } + } + }; + + Ok(stream.boxed()) + } + + async fn prompt( + &self, + message: &str, + options: &LlmProviderOptions, + ) -> Result { + let ollama_options = OllamaOptions { + temperature: options.temperature, + num_predict: options.max_tokens, + ..Default::default() + }; + let request = OllamaCompletionRequest { + model: &options.model, + prompt: message, + stream: Some(false), + options: Some(ollama_options), + }; + let response = self + .client + .post(format!("{}{}", self.base_url, CHAT_API_URL)) + .header("content-type", "application/json") + .json(&request) + .send() + .await + .map_err(|e| LlmError::ProviderError(format!("Ollama request failed: {}", e)))?; + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(LlmError::ProviderError(format!( + "Ollama API error {}: {}", + status, error_text + ))); + } + + let ollama_response: OllamaCompletionResponse = response + .json() + .await + .map_err(|e| LlmError::ProviderError(format!("Failed to parse response: {}", e)))?; + if let Some(usage) = Option::::from(&ollama_response) { + println!("Prompt usage: {:?}", usage); + } + if ollama_response.response.is_empty() { + return Err(LlmError::NoResponse); + } + + Ok(ollama_response.response) + } + + async fn list_models(&self) -> Result, LlmError> { + // For now, return an empty list since Ollama doesn't appear in models.dev + // In a real implementation, you might call Ollama's /api/tags endpoint + // or maintain a static list of supported models + Ok(Vec::new()) + } +} diff --git a/server/src/provider/ollama/request.rs b/server/src/provider/ollama/request.rs new file mode 100644 index 0000000..520b41f --- /dev/null +++ b/server/src/provider/ollama/request.rs @@ -0,0 +1,166 @@ +//! Ollama API request structures + +use serde::Serialize; + +use crate::{ + db::models::{ChatRsMessage, ChatRsMessageRole}, + provider::LlmTool, + tools::ToolParameters, +}; + +/// Convert ChatRsMessages to Ollama messages +pub fn build_ollama_messages(messages: &[ChatRsMessage]) -> Vec { + messages + .iter() + .map(|msg| { + let role = match msg.role { + ChatRsMessageRole::User => "user", + ChatRsMessageRole::Assistant => "assistant", + ChatRsMessageRole::System => "system", + ChatRsMessageRole::Tool => "tool", + }; + + let mut ollama_msg = OllamaMessage { + role, + content: &msg.content, + tool_calls: None, + tool_name: None, + }; + + // Handle tool calls in assistant messages + if msg.role == ChatRsMessageRole::Assistant { + if let Some(msg_tool_calls) = msg + .meta + .assistant + .as_ref() + .and_then(|m| m.tool_calls.as_ref()) + { + let tool_calls = msg_tool_calls + .iter() + .map(|tc| OllamaToolCall { + function: OllamaToolFunction { + name: &tc.tool_name, + arguments: &tc.parameters, + }, + }) + .collect(); + ollama_msg.tool_calls = Some(tool_calls); + } + } + + // Handle tool messages (results from tool calls) + if msg.role == ChatRsMessageRole::Tool { + if let Some(ref tool_call) = msg.meta.tool_call { + ollama_msg.tool_name = Some(&tool_call.tool_name); + } + } + + ollama_msg + }) + .collect() +} + +/// Convert LlmTools to Ollama tools +pub fn build_ollama_tools(tools: &[LlmTool]) -> Vec { + tools + .iter() + .map(|tool| OllamaTool { + r#type: "function", + function: OllamaToolSpec { + name: &tool.name, + description: &tool.description, + parameters: &tool.input_schema, + }, + }) + .collect() +} + +/// Ollama chat request structure +#[derive(Debug, Serialize)] +pub struct OllamaChatRequest<'a> { + pub model: &'a str, + pub messages: Vec>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +/// Ollama completion request structure +#[derive(Debug, Serialize)] +pub struct OllamaCompletionRequest<'a> { + pub model: &'a str, + pub prompt: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +/// Ollama chat message +#[derive(Debug, Serialize)] +pub struct OllamaMessage<'a> { + pub role: &'a str, + pub content: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_name: Option<&'a str>, +} + +/// Ollama tool call in a message +#[derive(Debug, Serialize)] +pub struct OllamaToolCall<'a> { + pub function: OllamaToolFunction<'a>, +} + +/// Ollama tool function +#[derive(Debug, Serialize)] +pub struct OllamaToolFunction<'a> { + pub name: &'a str, + pub arguments: &'a ToolParameters, +} + +/// Ollama tool definition +#[derive(Debug, Serialize)] +pub struct OllamaTool<'a> { + pub r#type: &'a str, + pub function: OllamaToolSpec<'a>, +} + +/// Ollama tool specification +#[derive(Debug, Serialize)] +pub struct OllamaToolSpec<'a> { + pub name: &'a str, + pub description: &'a str, + pub parameters: &'a serde_json::Value, +} + +/// Ollama model options +#[derive(Debug, Default, Serialize)] +pub struct OllamaOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub num_predict: Option, // Ollama's equivalent to max_tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, +} + +impl Default for OllamaChatRequest<'_> { + fn default() -> Self { + Self { + model: "", + messages: Vec::new(), + tools: None, + stream: None, + options: None, + } + } +} diff --git a/server/src/provider/ollama/response.rs b/server/src/provider/ollama/response.rs new file mode 100644 index 0000000..a3eea82 --- /dev/null +++ b/server/src/provider/ollama/response.rs @@ -0,0 +1,155 @@ +//! Ollama API response structures + +use serde::Deserialize; + +use crate::{ + db::models::ChatRsToolCall, + provider::{LlmPendingToolCall, LlmStreamChunk, LlmStreamError, LlmTool, LlmUsage}, +}; + +/// Parse Ollama streaming event into LlmStreamChunks, and track tool calls +pub fn parse_ollama_event( + event: OllamaStreamResponse, + tool_calls: &mut Vec, +) -> Vec> { + let mut chunks = Vec::with_capacity(1); + // Handle final message with usage stats + if event.done { + if let Some(usage) = Option::::from(&event) { + chunks.push(Ok(LlmStreamChunk::Usage(usage))); + } + } + + // Handle tool calls in the message + if !event.message.tool_calls.is_empty() { + for (index, tc) in event.message.tool_calls.iter().enumerate() { + let tool_call = LlmPendingToolCall { + index, + tool_name: tc.function.name.clone(), + }; + chunks.push(Ok(LlmStreamChunk::PendingToolCall(tool_call))); + } + tool_calls.extend(event.message.tool_calls); + } + + // Handle text content + if !event.message.content.is_empty() { + chunks.push(Ok(LlmStreamChunk::Text(event.message.content))); + } + + chunks +} + +/// Ollama chat response (streaming) +#[derive(Debug, Deserialize)] +pub struct OllamaStreamResponse { + pub model: String, + pub created_at: String, + pub message: OllamaMessage, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub total_duration: Option, + #[serde(default)] + pub load_duration: Option, + #[serde(default)] + pub prompt_eval_count: Option, + #[serde(default)] + pub prompt_eval_duration: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub eval_duration: Option, +} + +/// Ollama completion response (non-streaming) +#[derive(Debug, Deserialize)] +pub struct OllamaCompletionResponse { + pub model: String, + pub created_at: String, + pub response: String, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub total_duration: Option, + #[serde(default)] + pub load_duration: Option, + #[serde(default)] + pub prompt_eval_count: Option, + #[serde(default)] + pub prompt_eval_duration: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub eval_duration: Option, +} + +/// Ollama message in response +#[derive(Debug, Deserialize)] +pub struct OllamaMessage { + pub role: String, + #[serde(default)] + pub content: String, + #[serde(default)] + pub tool_calls: Vec, +} + +/// Ollama tool call in response +#[derive(Debug, Deserialize)] +pub struct OllamaToolCall { + pub function: OllamaToolFunction, +} + +/// Ollama tool function in response +#[derive(Debug, Deserialize)] +pub struct OllamaToolFunction { + pub name: String, + pub arguments: serde_json::Value, +} + +impl OllamaToolFunction { + /// Convert to ChatRsToolCall if the tool exists in the provided tools + pub fn convert(self, tools: &[LlmTool]) -> Option { + let tool = tools.iter().find(|t| t.name == self.name)?; + let parameters = serde_json::from_value(self.arguments).ok()?; + + Some(ChatRsToolCall { + id: uuid::Uuid::new_v4().to_string(), + parameters, + tool_id: tool.tool_id, + tool_name: self.name, + tool_type: tool.tool_type, + }) + } +} + +/// Convert Ollama usage to LlmUsage +impl From<&OllamaCompletionResponse> for Option { + fn from(response: &OllamaCompletionResponse) -> Self { + if response.prompt_eval_count.is_some() || response.eval_count.is_some() { + Some(LlmUsage { + input_tokens: response.prompt_eval_count, + output_tokens: response.eval_count, + cost: None, + }) + } else { + None + } + } +} + +impl From<&OllamaStreamResponse> for Option { + fn from(response: &OllamaStreamResponse) -> Self { + if response.prompt_eval_count.is_some() || response.eval_count.is_some() { + Some(LlmUsage { + input_tokens: response.prompt_eval_count, + output_tokens: response.eval_count, + cost: None, + }) + } else { + None + } + } +} From 8f291e9f04372117a318299c773b1e7c8baa0f3b Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 29 Aug 2025 01:08:08 -0400 Subject: [PATCH 41/46] add ollama models --- server/src/provider/ollama.rs | 46 ++++++++++++++++++++++---- server/src/provider/ollama/response.rs | 31 +++++++++++++++++ server/src/provider_models.rs | 24 +++++++------- 3 files changed, 83 insertions(+), 18 deletions(-) diff --git a/server/src/provider/ollama.rs b/server/src/provider/ollama.rs index 00cb93a..65c5ff7 100644 --- a/server/src/provider/ollama.rs +++ b/server/src/provider/ollama.rs @@ -13,7 +13,9 @@ use crate::{ build_ollama_messages, build_ollama_tools, OllamaChatRequest, OllamaCompletionRequest, OllamaOptions, }, - response::{parse_ollama_event, OllamaCompletionResponse, OllamaToolCall}, + response::{ + parse_ollama_event, OllamaCompletionResponse, OllamaModelsResponse, OllamaToolCall, + }, }, utils::get_sse_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, LlmTool, LlmUsage, @@ -22,6 +24,8 @@ use crate::{ }; const CHAT_API_URL: &str = "/api/chat"; +const COMPLETION_API_URL: &str = "/api/generate"; +const MODELS_API_URL: &str = "/api/tags"; /// Ollama chat provider #[derive(Debug, Clone)] @@ -124,7 +128,7 @@ impl LlmApiProvider for OllamaProvider { }; let response = self .client - .post(format!("{}{}", self.base_url, CHAT_API_URL)) + .post(format!("{}{}", self.base_url, COMPLETION_API_URL)) .header("content-type", "application/json") .json(&request) .send() @@ -154,9 +158,39 @@ impl LlmApiProvider for OllamaProvider { } async fn list_models(&self) -> Result, LlmError> { - // For now, return an empty list since Ollama doesn't appear in models.dev - // In a real implementation, you might call Ollama's /api/tags endpoint - // or maintain a static list of supported models - Ok(Vec::new()) + let response = self + .client + .get(format!("{}{}", self.base_url, MODELS_API_URL)) + .send() + .await + .map_err(|e| LlmError::ProviderError(format!("Ollama models request failed: {}", e)))?; + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(LlmError::ProviderError(format!( + "Ollama models API error {}: {}", + status, error_text + ))); + } + + let models_response: OllamaModelsResponse = response.json().await.map_err(|e| { + LlmError::ProviderError(format!("Failed to parse models response: {}", e)) + })?; + let models = models_response + .models + .into_iter() + .map(|model| LlmModel { + id: model.name.clone(), + name: model.name, + temperature: Some(true), + tool_call: Some(true), + modified_at: Some(model.modified_at), + format: Some(model.details.format), + family: Some(model.details.family), + ..Default::default() + }) + .collect(); + + Ok(models) } } diff --git a/server/src/provider/ollama/response.rs b/server/src/provider/ollama/response.rs index a3eea82..285b61e 100644 --- a/server/src/provider/ollama/response.rs +++ b/server/src/provider/ollama/response.rs @@ -153,3 +153,34 @@ impl From<&OllamaStreamResponse> for Option { } } } + +/// Ollama models list response +#[derive(Debug, Deserialize)] +pub struct OllamaModelsResponse { + pub models: Vec, +} + +/// Ollama model information +#[derive(Debug, Deserialize)] +pub struct OllamaModelInfo { + pub name: String, + // pub model: String, + pub modified_at: String, + // pub size: u64, + // pub digest: String, + pub details: OllamaModelDetails, +} + +/// Ollama model details +#[derive(Debug, Deserialize)] +pub struct OllamaModelDetails { + // #[serde(default)] + // pub parent_model: String, + pub format: String, + pub family: String, + // #[serde(default)] + // pub families: Vec, + // pub parameter_size: String, + // #[serde(default)] + // pub quantization_level: Option, +} diff --git a/server/src/provider_models.rs b/server/src/provider_models.rs index 2b32d89..6ed5bc8 100644 --- a/server/src/provider_models.rs +++ b/server/src/provider_models.rs @@ -13,29 +13,29 @@ const CACHE_TTL: i64 = 86400; // 1 day in seconds /// A model supported by the LLM provider #[derive(Debug, Default, Clone, JsonSchema, Serialize, Deserialize)] pub struct LlmModel { - id: String, - name: String, + pub id: String, + pub name: String, #[serde(skip_serializing_if = "Option::is_none")] - attachment: Option, + pub attachment: Option, #[serde(skip_serializing_if = "Option::is_none")] - reasoning: Option, + pub reasoning: Option, #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, + pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] - tool_call: Option, + pub tool_call: Option, #[serde(skip_serializing_if = "Option::is_none")] - release_date: Option, + pub release_date: Option, #[serde(skip_serializing_if = "Option::is_none")] - knowledge: Option, + pub knowledge: Option, #[serde(skip_serializing_if = "Option::is_none")] - modalities: Option, + pub modalities: Option, // Ollama fields #[serde(skip_serializing_if = "Option::is_none")] - modified_at: Option, + pub modified_at: Option, #[serde(skip_serializing_if = "Option::is_none")] - format: Option, + pub format: Option, #[serde(skip_serializing_if = "Option::is_none")] - family: Option, + pub family: Option, } #[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)] From 203daabb0a10c0c99ad6468840be3d2736676ba6 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 29 Aug 2025 01:59:34 -0400 Subject: [PATCH 42/46] fix ollama streaming --- server/src/provider/ollama.rs | 27 ++++++++++++++------------- server/src/provider/utils.rs | 12 ++++++++++++ server/src/stream/llm_writer.rs | 2 +- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/server/src/provider/ollama.rs b/server/src/provider/ollama.rs index 65c5ff7..68ef215 100644 --- a/server/src/provider/ollama.rs +++ b/server/src/provider/ollama.rs @@ -8,21 +8,22 @@ use rocket::{async_stream, async_trait, futures::StreamExt}; use crate::{ db::models::ChatRsMessage, provider::{ - ollama::{ - request::{ - build_ollama_messages, build_ollama_tools, OllamaChatRequest, - OllamaCompletionRequest, OllamaOptions, - }, - response::{ - parse_ollama_event, OllamaCompletionResponse, OllamaModelsResponse, OllamaToolCall, - }, - }, - utils::get_sse_events, - LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, LlmTool, LlmUsage, + utils::get_json_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, + LlmStreamChunk, LlmTool, LlmUsage, }, provider_models::LlmModel, }; +use { + request::{ + build_ollama_messages, build_ollama_tools, OllamaChatRequest, OllamaCompletionRequest, + OllamaOptions, + }, + response::{ + parse_ollama_event, OllamaCompletionResponse, OllamaModelsResponse, OllamaToolCall, + }, +}; + const CHAT_API_URL: &str = "/api/chat"; const COMPLETION_API_URL: &str = "/api/generate"; const MODELS_API_URL: &str = "/api/tags"; @@ -84,9 +85,9 @@ impl LlmApiProvider for OllamaProvider { } let stream = async_stream::stream! { - let mut sse_event_stream = get_sse_events(response); + let mut json_stream = get_json_events(response); let mut tool_calls: Vec = Vec::new(); - while let Some(event) = sse_event_stream.next().await { + while let Some(event) = json_stream.next().await { match event { Ok(event) => { for chunk in parse_ollama_event(event, &mut tool_calls) { diff --git a/server/src/provider/utils.rs b/server/src/provider/utils.rs index 8c2b102..1845ab1 100644 --- a/server/src/provider/utils.rs +++ b/server/src/provider/utils.rs @@ -33,3 +33,15 @@ pub fn get_sse_events( } }) } + +/// Get a stream of deserialized events from a provider JSON stream, not SSE (e.g. Ollama uses this format). +pub fn get_json_events( + response: reqwest::Response, +) -> impl Stream> { + let stream_reader = StreamReader::new(response.bytes_stream().map_err(std::io::Error::other)); + let line_reader = FramedRead::new(stream_reader, LinesCodec::new()); + line_reader.map(|line_result| match line_result { + Ok(line) => serde_json::from_str::(&line).map_err(LlmStreamError::Parsing), + Err(e) => Err(LlmStreamError::Decoding(e)), + }) +} diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 5ca3129..0c3f499 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -16,7 +16,7 @@ use crate::{ }; /// Interval at which chunks are flushed to the Redis stream. -const FLUSH_INTERVAL: Duration = Duration::from_millis(300); +const FLUSH_INTERVAL: Duration = Duration::from_millis(500); /// Max accumulated size of the text chunk before it is automatically flushed to Redis. const MAX_CHUNK_SIZE: usize = 200; /// Expiration in seconds set on the Redis stream (normally, the Redis stream will be deleted before this) From 1808d3b33e0f69be9c4dc37c3db93377ccd837cb Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 29 Aug 2025 02:13:58 -0400 Subject: [PATCH 43/46] web: add Ollama provider --- web/src/components/ProviderManager.tsx | 57 +++++++++-- web/src/hooks/useChatInputState.tsx | 2 +- web/src/lib/api/types.d.ts | 136 +++++++++++++++++++++++-- 3 files changed, 175 insertions(+), 20 deletions(-) diff --git a/web/src/components/ProviderManager.tsx b/web/src/components/ProviderManager.tsx index 85d2139..5e96f7c 100644 --- a/web/src/components/ProviderManager.tsx +++ b/web/src/components/ProviderManager.tsx @@ -1,5 +1,5 @@ import { Bot, Plus, Trash2 } from "lucide-react"; -import { type FormEventHandler, useState } from "react"; +import { type FormEventHandler, useId, useState } from "react"; import { AlertDialog, @@ -39,13 +39,14 @@ import { import type { components } from "@/lib/api/types"; import { cn } from "@/lib/utils"; -type AIProvider = "anthropic" | "openai" | "openrouter" | "lorem"; +type AIProvider = "anthropic" | "openai" | "openrouter" | "ollama" | "lorem"; interface ProviderInfo { name: string; description: string; apiType: components["schemas"]["ChatRsProvider"]["provider_type"]; baseUrl?: string; + customBaseUrl?: boolean; keyFormat?: string; color: string; defaultModel: string; @@ -79,6 +80,16 @@ const PROVIDERS: Record = { color: "bg-blue-100 dark:bg-blue-900 border-blue-300 dark:border-blue-700", defaultModel: "openai/gpt-4o-mini", }, + ollama: { + name: "Ollama", + description: "Run LLMs locally with Ollama", + apiType: "ollama", + baseUrl: "http://localhost:11434", + customBaseUrl: true, + color: + "bg-indigo-100 dark:bg-indigo-900 border-indigo-300 dark:border-indigo-700", + defaultModel: "llama3.2", + }, lorem: { name: "Lorem Ipsum", description: "Generate placeholder text for testing", @@ -103,14 +114,22 @@ export function ProviderManager({ const [selectedProvider, setSelectedProvider] = useState( null, ); + const [name, setName] = useState(""); const [newApiKey, setNewApiKey] = useState(""); + const [newBaseUrl, setNewBaseUrl] = useState(""); + + const nameId = useId(); + const apiKeyId = useId(); + const baseUrlId = useId(); const handleCreateKey: FormEventHandler = (event) => { event.preventDefault(); if ( !selectedProvider || - (PROVIDERS[selectedProvider].apiType !== "lorem" && !newApiKey.trim()) + (PROVIDERS[selectedProvider].apiType !== "lorem" && + PROVIDERS[selectedProvider].apiType !== "ollama" && + !newApiKey.trim()) ) return; @@ -118,7 +137,7 @@ export function ProviderManager({ { name, type: PROVIDERS[selectedProvider].apiType, - base_url: PROVIDERS[selectedProvider].baseUrl, + base_url: newBaseUrl || PROVIDERS[selectedProvider].baseUrl, default_model: PROVIDERS[selectedProvider].defaultModel, api_key: newApiKey || null, }, @@ -127,6 +146,7 @@ export function ProviderManager({ setSelectedProvider(null); setIsCreateDialogOpen(false); setNewApiKey(""); + setNewBaseUrl(""); setName(""); }, }, @@ -204,19 +224,21 @@ export function ProviderManager({ {PROVIDERS[provider].description} -
- {PROVIDERS[provider].keyFormat} -
+ {PROVIDERS[provider].keyFormat && ( +
+ {PROVIDERS[provider].keyFormat} +
+ )} ))}
- + {selectedProvider && PROVIDERS[selectedProvider].keyFormat && (
- +
)} + {selectedProvider && + PROVIDERS[selectedProvider].customBaseUrl && ( +
+ + setNewBaseUrl(e.target.value)} + /> +
+ )}
+ )}