diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index c3e3e29..a9dddf5 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -2,7 +2,7 @@ ## Overview -RsChat is a real-time chat application that provides resumable streaming conversations with LLM providers. The architecture is designed for high performance, scalability across multiple server instances, and resilient streaming that can survive network interruptions. +RsChat is an application for chatting with multiple LLM providers. The architecture is designed for high performance, scalability across multiple server instances, and resilient streaming that can survive network interruptions. ## Core Architecture @@ -14,7 +14,7 @@ RsChat is a real-time chat application that provides resumable streaming convers ### Backend (Rust/Rocket) - **Location**: `server/` -- **Framework**: Rocket with async/await support +- **Framework**: Rocket - **Database**: PostgreSQL for persistent storage - **Cache/Streaming**: Redis for stream management and caching @@ -24,8 +24,8 @@ RsChat is a real-time chat application that provides resumable streaming convers RsChat uses a hybrid streaming architecture that provides both real-time performance and cross-instance resumability: -1. **Server**: Redis Streams for resumability and multi-instance support -2. **Client**: Server-Sent Events (SSE) read from the Redis streams +1. **Server**: Redis Streams created from the provider responses, for resumability and multi-instance support, as well as simultaneous streaming to multiple clients. +2. **Client**: Server-Sent Events (SSE) read from the Redis streams. ### Key Components @@ -34,9 +34,8 @@ RsChat uses a hybrid streaming architecture that provides both real-time perform The core component that processes LLM provider streams and manages Redis stream output. 
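To make the writer's role concrete, here is a minimal, self-contained sketch of the chunk-batching idea it implements (accumulate provider chunks up to a maximum length or age, then flush — in the real code the flush is a Redis `XADD`). The `ChunkBatcher` name, its limits, and its API are illustrative assumptions, not RsChat's actual implementation:

```rust
use std::time::{Duration, Instant};

/// Illustrative sketch (not RsChat's real code): accumulate incoming
/// chunks until either a max length or a max age is reached, then
/// hand back the batch for flushing (e.g. via Redis XADD).
pub struct ChunkBatcher {
    buf: String,
    max_len: usize,
    max_age: Duration,
    started: Option<Instant>,
}

impl ChunkBatcher {
    pub fn new(max_len: usize, max_age: Duration) -> Self {
        Self { buf: String::new(), max_len, max_age, started: None }
    }

    /// Add a chunk; returns Some(batch) when the batch should be flushed.
    pub fn push(&mut self, chunk: &str) -> Option<String> {
        if self.buf.is_empty() {
            // Start the age timer when the first chunk of a batch arrives.
            self.started = Some(Instant::now());
        }
        self.buf.push_str(chunk);
        let timed_out = self
            .started
            .map(|t| t.elapsed() >= self.max_age)
            .unwrap_or(false);
        if self.buf.len() >= self.max_len || timed_out {
            self.started = None;
            Some(std::mem::take(&mut self.buf))
        } else {
            None
        }
    }

    /// Flush whatever is left when the provider stream ends.
    pub fn finish(&mut self) -> Option<String> {
        (!self.buf.is_empty()).then(|| std::mem::take(&mut self.buf))
    }
}
```

In the actual writer this logic runs inside the async loop over the provider stream, alongside the keepalive pings; the synchronous sketch only shows the flush condition.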
**Key Features:** -- **Batching**: Accumulates chunks from the provider stream, up to a max length or timeout +- **Batching**: Accumulates chunks from the provider stream, up to a max length or timeout, and adds them to the Redis stream - **Background Pings**: Sends regular keepalive pings -- **Database Integration**: Saves final responses to PostgreSQL #### 2. Redis and SSE Stream Structure @@ -46,6 +45,7 @@ The core component that processes LLM provider streams and manages Redis stream - `start`: Stream initialization - `text`: Accumulated text chunks - `tool_call`: LLM tool invocations (JSON stringified) +- `pending_tool_call`: Pending tool call invocations - `error`: Error messages - `ping`: Keepalive messages - `end`: Stream completion @@ -54,7 +54,7 @@ The core component that processes LLM provider streams and manages Redis stream #### 3. Stream Lifecycle ``` -Client Request → SSE Connection → LlmStreamWriter.create() +Client Request → LLM streaming response → LlmStreamWriter.create() ↓ LLM Provider Stream → Batching Data Chunks → Redis XADD ↓ @@ -77,9 +77,8 @@ Stream End → Database Save → Redis DEL ``` Client → POST /api/chat/{session_id} → Send request to LLM Provider - → SSE Response Stream created - → LlmStreamWriter.create() - → Redis Stream created + → LLM response received, streamed to Redis with the `LlmStreamWriter` + → GET /api/chat/{session_id}/stream to connect to the stream and stream the response ``` ### 2. Stream Processing @@ -87,7 +86,7 @@ Client → POST /api/chat/{session_id} LLM Chunk → Process text, tool calls, usage, and error chunks → Batching Logic → Redis XADD (if conditions met) - → Continue SSE Stream + → Client(s) receive the new chunks ``` ### 3. Stream Completion @@ -101,5 +100,5 @@ LLM End → Final Database Save ### 4. 
Reconnection/Resume ``` Client Reconnect → Check ongoing streams via GET /api/chat/streams - → Reconnect to stream (if active) + → Reconnect to any active streams ``` diff --git a/Dockerfile b/Dockerfile index 233adbf..8470909 100644 --- a/Dockerfile +++ b/Dockerfile @@ -44,7 +44,7 @@ RUN apt-get update -qq && \ apt-get install -y -qq ca-certificates libpq5 && \ apt-get clean -# Use non-root user +# Create non-root user and data directory ARG UID=10001 RUN adduser \ --disabled-password \ @@ -53,6 +53,9 @@ RUN adduser \ --shell "/sbin/nologin" \ --uid "${UID}" \ appuser +RUN mkdir -p /data +RUN chown -R appuser:appuser /data + USER appuser # Copy app files @@ -61,6 +64,7 @@ COPY --from=backend-build /app/run-server /usr/local/bin/ # Run ENV RS_CHAT_STATIC_PATH=/var/www +ENV RS_CHAT_DATA_DIR=/data ENV RS_CHAT_ADDRESS=0.0.0.0 ENV RS_CHAT_PORT=8080 EXPOSE 8080 diff --git a/README.md b/README.md index bb44dd1..570d6cd 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,7 @@ A fast, secure, self-hostable chat application built with Rust, TypeScript, and React. Chat with multiple AI providers using your own API keys, with real-time streaming built-in. -!! **Submission to the [T3 Chat Cloneathon](https://cloneathon.t3.chat/)** !! - -Demo link: https://rschat.fasharp.io (⚠️ This is a demo - don't expect your account/chats to be there when you come back. It may intermittently delete data. Please also don't enter any sensitive information or confidential data) +Demo link: https://rs-chat-demo.up.railway.app/ (⚠️ This is a demo - don't expect your account/chats to be there when you come back. It may intermittently delete all data. 
Please also don't enter any sensitive information or confidential data) ## ✨ Features diff --git a/biome.json b/biome.json index 1653499..840301a 100644 --- a/biome.json +++ b/biome.json @@ -1,5 +1,5 @@ { - "$schema": "https://biomejs.dev/schemas/2.0.0/schema.json", + "$schema": "https://biomejs.dev/schemas/2.2.2/schema.json", "vcs": { "enabled": true, "clientKind": "git", diff --git a/docker-compose.yml b/docker-compose.yml index 251c5d0..269973a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,9 +29,11 @@ services: RS_CHAT_SERVER_ADDRESS: http://localhost:8080 RS_CHAT_DATABASE_URL: postgres://postgres:postgres@postgres/postgres RS_CHAT_REDIS_URL: redis://redis:6379 + RS_CHAT_DATA_DIR: /data env_file: server/.env volumes: - ./.docker:/certs + - rschat_data:/data depends_on: - db - redis @@ -39,3 +41,4 @@ services: volumes: postgres_data: redis_data: + rschat_data: diff --git a/server/.env.example b/server/.env.example index 9c6a039..69f7954 100644 --- a/server/.env.example +++ b/server/.env.example @@ -10,5 +10,8 @@ RS_CHAT_GITHUB_CLIENT_SECRET=your_github_client_secret_here # You can generate one with: openssl rand -hex 32 RS_CHAT_SECRET_KEY=hex-secret-key-for-encryption-change-this +# Local data directory +RS_CHAT_DATA_DIR=.local + # Postgres URL for running migrations via the Diesel CLI DATABASE_URL=postgres://postgres:postgres@localhost/postgres diff --git a/server/.gitignore b/server/.gitignore new file mode 100644 index 0000000..569d9e1 --- /dev/null +++ b/server/.gitignore @@ -0,0 +1,2 @@ +# Local storage files +.local/ diff --git a/server/Cargo.lock b/server/Cargo.lock index 860b6fd..1f3a8a8 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -373,6 +373,7 @@ version = "0.6.0" dependencies = [ "aes-gcm", "astral-tokio-tar", + "base64 0.22.1", "bollard", "chrono", "const_format", diff --git a/server/Cargo.toml b/server/Cargo.toml index 3185807..35a060c 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -18,6 +18,7 @@ strip = true 
[dependencies] aes-gcm = "0.10.3" astral-tokio-tar = "0.5.2" +base64 = "0.22.1" bollard = { version = "0.19.1", features = ["ssl"] } chrono = { version = "0.4.41", features = ["serde"] } const_format = "0.2.34" diff --git a/server/migrations/2025-08-31-034235_add_files/down.sql b/server/migrations/2025-08-31-034235_add_files/down.sql new file mode 100644 index 0000000..38a7300 --- /dev/null +++ b/server/migrations/2025-08-31-034235_add_files/down.sql @@ -0,0 +1 @@ +DROP TABLE files; diff --git a/server/migrations/2025-08-31-034235_add_files/up.sql b/server/migrations/2025-08-31-034235_add_files/up.sql new file mode 100644 index 0000000..bbbc781 --- /dev/null +++ b/server/migrations/2025-08-31-034235_add_files/up.sql @@ -0,0 +1,18 @@ +CREATE TABLE files ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users (id), + session_id UUID REFERENCES chat_sessions (id) ON UPDATE CASCADE ON DELETE SET NULL, + path TEXT NOT NULL, + file_type TEXT NOT NULL, + content_type TEXT NOT NULL, + size INTEGER NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +SELECT + diesel_manage_updated_at ('files'); + +CREATE INDEX idx_files_user_id ON files (user_id); + +CREATE INDEX idx_files_session_id ON files (session_id); diff --git a/server/src/api.rs b/server/src/api.rs index 8033fe7..3c70abb 100644 --- a/server/src/api.rs +++ b/server/src/api.rs @@ -5,6 +5,7 @@ mod info; mod provider; mod secret; mod session; +mod storage; mod tool; pub use api_key::get_routes as api_key_routes; @@ -14,4 +15,5 @@ pub use info::get_routes as info_routes; pub use provider::get_routes as provider_routes; pub use secret::get_routes as secret_routes; pub use session::get_routes as session_routes; +pub use storage::get_routes as storage_routes; pub use tool::get_routes as tool_routes; diff --git a/server/src/api/auth.rs b/server/src/api/auth.rs index ce8bbf2..9cf1141 100644 --- 
a/server/src/api/auth.rs +++ b/server/src/api/auth.rs @@ -12,18 +12,8 @@ use rocket_okapi::{ use schemars::JsonSchema; use crate::{ - auth::{ - ChatRsAuthSession, DiscordOAuthConfig, GitHubOAuthConfig, GoogleOAuthConfig, OIDCConfig, - SSOHeaderMergedConfig, - }, - db::{ - models::ChatRsUser, - services::{ - ApiKeyDbService, ChatDbService, ProviderDbService, SecretDbService, ToolDbService, - UserDbService, - }, - DbConnection, - }, + auth::*, + db::{models::ChatRsUser, services::*, DbConnection}, errors::ApiError, }; diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index fbc44c5..aa1e94a 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -18,21 +18,16 @@ use crate::{ api::session::DEFAULT_SESSION_TITLE, auth::ChatRsUserId, db::{ - models::{ - AssistantMeta, ChatRsMessageMeta, ChatRsMessageRole, ChatRsSessionMeta, - NewChatRsMessage, UpdateChatRsSession, - }, + models::*, services::{ChatDbService, ProviderDbService, ToolDbService}, DbConnection, DbPool, }, errors::ApiError, - provider::{build_llm_provider_api, LlmError, LlmProviderOptions}, + provider::{build_llm_messages, build_llm_provider_api, LlmError, LlmProviderOptions}, redis::{ExclusiveRedisClient, RedisClient}, - stream::{ - cancel_current_chat_stream, check_chat_stream_exists, get_current_chat_streams, - LastEventId, LlmStreamWriter, SseStreamReader, - }, - tools::{get_llm_tools_from_input, SendChatToolInput}, + storage::LocalStorage, + stream::*, + tools::SendChatToolInput, utils::{generate_title, Encryptor}, }; @@ -72,6 +67,8 @@ pub struct SendChatInput<'a> { options: LlmProviderOptions, /// Configuration of tools available to the assistant tools: Option<SendChatToolInput>, + /// IDs of the file(s) to attach to this message + files: Option<Vec<Uuid>>, } #[derive(JsonSchema, serde::Serialize)] @@ -92,6 +89,7 @@ pub async fn send_chat_stream( redis: RedisClient, redis_writer: ExclusiveRedisClient, encryptor: &State<Encryptor>, + storage: &State<LocalStorage>, http_client: &State<reqwest::Client>, session_id: Uuid, mut input: Json<SendChatInput<'_>>, @@ -123,12
+121,26 @@ pub async fn send_chat_stream( // Get the user's chosen tools let mut tools = None; - if let Some(tool_input) = input.tools.as_ref() { - let mut tool_db_service = ToolDbService::new(&mut db); - tools = Some(get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?); + if let Some(conf) = input.tools.take() { + let llm_tools = conf + .get_llm_tools(&user_id, &mut ToolDbService::new(&mut db)) + .await?; + tools = Some(llm_tools); + + // Update session metadata with new tools if needed + if session.meta.tool_config.as_ref().is_none_or(|c| *c != conf) { + let data = UpdateChatRsSession { + meta: Some(ChatRsSessionMeta::new(Some(conf))), + ..Default::default() + }; + ChatDbService::new(&mut db) + .update_session(&user_id, &session_id, data) + .await?; + } } // Generate session title if needed, and save user message to database + let attached_file_ids = input.files.take(); if let Some(user_message) = &input.message { if messages.is_empty() && session.title == DEFAULT_SESSION_TITLE { generate_title( @@ -140,47 +152,34 @@ pub async fn send_chat_stream( db_pool, ); } - let new_message = ChatDbService::new(&mut db) + let message_meta = attached_file_ids + .map(|ids| ChatRsMessageMeta::new_user(UserMeta { files: Some(ids) })) + .unwrap_or_default(); + let message = ChatDbService::new(&mut db) .save_message(NewChatRsMessage { content: user_message, session_id: &session_id, role: ChatRsMessageRole::User, - meta: ChatRsMessageMeta::default(), + meta: message_meta, }) .await?; - messages.push(new_message); + messages.push(message); } - // Update session metadata if needed - if let Some(tool_input) = input.tools.take() { - if session - .meta - .tool_config - .is_none_or(|config| config != tool_input) - { - let meta = ChatRsSessionMeta::new(Some(tool_input)); - let data = UpdateChatRsSession { - meta: Some(&meta), - ..Default::default() - }; - ChatDbService::new(&mut db) - .update_session(&user_id, &session_id, data) - .await?; - } - } - - // Get the 
provider's stream response + // Convert the messages, and get the provider's response stream + let llm_messages = + build_llm_messages(messages, &user_id, &session_id, &mut db, &storage).await?; let stream = provider_api - .chat_stream(messages, tools, &input.options) + .chat_stream(llm_messages, tools, &input.options) .await?; - let provider_id = input.provider_id; - let provider_options = input.options.clone(); // Create the Redis stream let mut stream_writer = LlmStreamWriter::new(redis_writer, &user_id, &session_id); stream_writer.start().await?; // Spawn a task to stream and save the response + let provider_id = input.provider_id.clone(); + let provider_options = input.options.clone(); tokio::spawn(async move { let (text, tool_calls, usage, errors, cancelled) = stream_writer.process(stream).await; let assistant_meta = AssistantMeta { diff --git a/server/src/api/provider.rs b/server/src/api/provider.rs index e66fa02..9ea73d4 100644 --- a/server/src/api/provider.rs +++ b/server/src/api/provider.rs @@ -8,16 +8,12 @@ use uuid::Uuid; use crate::{ auth::ChatRsUserId, db::{ - models::{ - ChatRsProvider, ChatRsProviderType, NewChatRsProvider, NewChatRsSecret, - UpdateChatRsProvider, UpdateChatRsSecret, - }, + models::*, services::{ProviderDbService, SecretDbService}, DbConnection, }, errors::ApiError, - provider::build_llm_provider_api, - provider_models::LlmModel, + provider::{build_llm_provider_api, models::LlmModel}, redis::RedisClient, utils::Encryptor, }; diff --git a/server/src/api/storage.rs b/server/src/api/storage.rs new file mode 100644 index 0000000..c39282d --- /dev/null +++ b/server/src/api/storage.rs @@ -0,0 +1,108 @@ +use std::path::PathBuf; + +use rocket::{delete, fs::NamedFile, get, post, serde::json::Json, Route, State}; +use rocket_okapi::{ + okapi::openapi3::OpenApi, openapi, openapi_get_routes_spec, settings::OpenApiSettings, +}; +use uuid::Uuid; + +use crate::{ + auth::ChatRsUserId, + db::{ + models::{ChatRsFile, NewChatRsFile}, + 
services::FileDbService, + DbConnection, + }, + errors::ApiError, + storage::{FileData, LocalStorage}, +}; + +pub fn get_routes(settings: &OpenApiSettings) -> (Vec<Route>, OpenApi) { + openapi_get_routes_spec![settings: upload_file, download_file, list_session_files, delete_file] +} + +/// List session files +#[openapi(tag = "Storage")] +#[get("/<session_id>")] +async fn list_session_files( + user_id: ChatRsUserId, + session_id: Uuid, + mut db: DbConnection, +) -> Result<Json<Vec<ChatRsFile>>, ApiError> { + let files = FileDbService::new(&mut db) + .list_session_files(&user_id, &session_id) + .await?; + + Ok(Json(files)) +} + +/// Upload a new session file +#[openapi(tag = "Storage")] +#[post("/<session_id>/<path..>", data = "<file>")] +async fn upload_file( + user_id: ChatRsUserId, + session_id: Uuid, + path: PathBuf, + file: FileData<'_>, + storage: &State<LocalStorage>, + mut db: DbConnection, +) -> Result<Json<ChatRsFile>, ApiError> { + let size = storage + .create_file(&user_id, Some(&session_id), &path, file.data) + .await?; + let db_file = FileDbService::new(&mut db) + .create_session_file(NewChatRsFile { + user_id: &user_id, + session_id: Some(&session_id), + path: &path.to_string_lossy(), + file_type: file.file_type.into(), + content_type: &file.content_type.to_string(), + size: size.try_into().unwrap_or_default(), + }) + .await?; + + Ok(Json(db_file)) +} + +/// Download a session file +#[openapi(tag = "Storage")] +#[get("/<session_id>/<file_id>")] +async fn download_file( + user_id: ChatRsUserId, + session_id: Uuid, + file_id: Uuid, + storage: &State<LocalStorage>, + mut db: DbConnection, +) -> Result<NamedFile, ApiError> { + let file = FileDbService::new(&mut db) + .find_session_file(&user_id, &session_id, &file_id) + .await?; + let file_path = storage.get_file_path(&user_id, Some(&session_id), &file.path)?; + + Ok(NamedFile::open(file_path).await?)
+} + +/// Delete a session file +#[openapi(tag = "Storage")] +#[delete("/<session_id>/<file_id>")] +async fn delete_file( + user_id: ChatRsUserId, + session_id: Uuid, + file_id: Uuid, + storage: &State<LocalStorage>, + mut db: DbConnection, +) -> Result<String, ApiError> { + let mut db_service = FileDbService::new(&mut db); + let file = db_service + .find_session_file(&user_id, &session_id, &file_id) + .await?; + + storage + .delete_file(&user_id, Some(&session_id), &file.path) + .await?; + db_service + .delete_session_file(&user_id, &session_id, &file_id) + .await?; + + Ok(file_id.to_string()) +} diff --git a/server/src/api/tool.rs b/server/src/api/tool.rs index b9f9240..e2053b3 100644 --- a/server/src/api/tool.rs +++ b/server/src/api/tool.rs @@ -18,20 +18,15 @@ use crate::{ api::secret::SecretInput, auth::ChatRsUserId, + config::AppConfig, db::{ - models::{ - ChatRsExecutedToolCall, ChatRsExternalApiTool, ChatRsMessageMeta, ChatRsMessageRole, - ChatRsSystemTool, NewChatRsExternalApiTool, NewChatRsMessage, NewChatRsSecret, - NewChatRsSystemTool, - }, + models::*, services::{ChatDbService, SecretDbService, ToolDbService}, DbConnection, }, errors::ApiError, provider::LlmToolType, - tools::{ - ChatRsExternalApiToolConfig, ChatRsSystemToolConfig, ToolError, ToolLog, ToolResponseFormat, - }, + tools::*, utils::{Encryptor, SenderWithLogging}, }; @@ -147,6 +142,7 @@ async fn execute_tool( user_id: ChatRsUserId, mut db: DbConnection, http_client: &State<reqwest::Client>, + app_config: &State<AppConfig>, encryptor: &State<Encryptor>, message_id: Uuid, tool_call_id: &str, @@ -188,6 +184,7 @@ async fn execute_tool( }; let (streaming_tx, streaming_rx) = tokio::sync::mpsc::channel(50); + let app_config = app_config.inner().clone(); let http_client = http_client.inner().clone(); let secrets = secret_1.into_iter().collect::<Vec<_>>(); @@ -217,7 +214,7 @@ let tool_result = match (system_tool, external_api_tool) { (Some(system_tool), None) => { system_tool - .build_executor() + .build_executor(&mut db, &app_config, &message.session_id)
.validate_and_execute( &tool_call.tool_name, &tool_call.parameters, diff --git a/server/src/auth/guard.rs b/server/src/auth/guard.rs index 58787fd..ebd3487 100644 --- a/server/src/auth/guard.rs +++ b/server/src/auth/guard.rs @@ -24,7 +24,7 @@ use crate::{ }; /// User ID request guard to ensure a logged-in user. -pub struct ChatRsUserId(pub Uuid); +pub struct ChatRsUserId(pub(super) Uuid); impl Deref for ChatRsUserId { type Target = Uuid; @@ -40,14 +40,6 @@ impl<'r> FromRequest<'r> for ChatRsUserId { type Error = &'r str; async fn from_request(req: &'r rocket::Request<'_>) -> Outcome<Self, Self::Error> { - // Try authentication via proxy headers if configured - if let Some(config) = req.rocket().state::<SSOHeaderMergedConfig>() { - if let Some(proxy_user) = get_sso_user_from_headers(config, req.headers()) { - let mut db = try_outcome!(req.guard::<DbConnection>().await); - return get_sso_auth_outcome(&proxy_user, config, &mut db).await; - } - }; - // Try authentication via API key if let Some(auth_header) = req.headers().get_one("Authorization") { let encryptor = req.rocket().state::<Encryptor>().expect("should exist"); @@ -57,10 +49,21 @@ impl<'r> FromRequest<'r> for ChatRsUserId { // Try authentication via session let session = try_outcome!(req.guard::<Session<'_, ChatRsAuthSession>>().await); - match session.tap(|data| data.and_then(|auth_session| auth_session.user_id())) { - Some(user_id) => Outcome::Success(ChatRsUserId(user_id)), - None => Outcome::Error((Status::Unauthorized, "Unauthorized")), + if let Some(user_id) = + session.tap(|data| data.and_then(|auth_session| auth_session.user_id())) + { + return Outcome::Success(ChatRsUserId(user_id)); } + + // Try authentication via proxy headers if configured + if let Some(config) = req.rocket().state::<SSOHeaderMergedConfig>() { + if let Some(proxy_user) = get_sso_user_from_headers(config, req.headers()) { + let mut db = try_outcome!(req.guard::<DbConnection>().await); + return get_sso_auth_outcome(&proxy_user, config, &mut db).await; + } + }; + + Outcome::Error((Status::Unauthorized, "Unauthorized")) } } diff --git a/server/src/config.rs
b/server/src/config.rs index 0fc5183..25d1f3c 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -8,7 +8,7 @@ use rocket::{ use serde::{Deserialize, Serialize}; /// Main server config (settings are merged with Rocket's default config) -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct AppConfig { /// 32-byte hex string (64 characters) used for encrypting cookies and API keys pub secret_key: String, @@ -16,6 +16,8 @@ pub struct AppConfig { pub server_address: String, /// Static files directory (default: "../web/dist") pub static_path: Option<String>, + /// Local data directory (default: "/data") + pub data_dir: Option<String>, /// Postgres Database URL pub database_url: String, /// Redis connection URL diff --git a/server/src/db/models.rs b/server/src/db/models.rs index fbb198b..73f3d74 100644 --- a/server/src/db/models.rs +++ b/server/src/db/models.rs @@ -1,5 +1,6 @@ mod api_key; mod chat; +mod file; mod provider; mod secret; mod tool; @@ -9,6 +10,7 @@ use crate::db::schema; pub use api_key::*; pub use chat::*; +pub use file::*; pub use provider::*; pub use secret::*; pub use tool::*; diff --git a/server/src/db/models/api_key.rs b/server/src/db/models/api_key.rs index 163bbf2..e87f42b 100644 --- a/server/src/db/models/api_key.rs +++ b/server/src/db/models/api_key.rs @@ -1,8 +1,5 @@ use chrono::{DateTime, Utc}; -use diesel::{ - prelude::{Associations, Identifiable, Insertable, Queryable}, - Selectable, -}; +use diesel::prelude::*; use schemars::JsonSchema; use uuid::Uuid; diff --git a/server/src/db/models/chat.rs b/server/src/db/models/chat.rs index ad3db9e..1b27574 100644 --- a/server/src/db/models/chat.rs +++ b/server/src/db/models/chat.rs @@ -1,8 +1,5 @@ use chrono::{DateTime, Utc}; -use diesel::{ - prelude::{AsChangeset, Associations, Identifiable, Insertable, Queryable}, - Selectable, -}; +use diesel::prelude::*; use diesel_as_jsonb::AsJsonb; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -14,7
+11,7 @@ use crate::{ tools::SendChatToolInput, }; -#[derive(Identifiable, Associations, Queryable, Selectable, JsonSchema, serde::Serialize)] +#[derive(Identifiable, Associations, Queryable, Selectable, JsonSchema, Serialize)] #[diesel(belongs_to(ChatRsUser, foreign_key = user_id))] #[diesel(table_name = super::schema::chat_sessions)] pub struct ChatRsSession { @@ -50,12 +47,12 @@ pub struct NewChatRsSession<'r> { #[diesel(table_name = super::schema::chat_sessions)] pub struct UpdateChatRsSession<'r> { pub title: Option<&'r str>, - pub meta: Option<&'r ChatRsSessionMeta>, + pub meta: Option<ChatRsSessionMeta>, } #[derive(diesel_derive_enum::DbEnum)] #[db_enum(existing_type_path = "crate::db::schema::sql_types::ChatMessageRole")] -#[derive(Debug, PartialEq, Eq, JsonSchema, serde::Serialize)] +#[derive(Debug, PartialEq, Eq, JsonSchema, Serialize)] pub enum ChatRsMessageRole { User, Assistant, @@ -63,7 +60,7 @@ Tool, } -#[derive(Identifiable, Queryable, Selectable, Associations, JsonSchema, serde::Serialize)] +#[derive(Identifiable, Queryable, Selectable, Associations, JsonSchema, Serialize)] #[diesel(belongs_to(ChatRsSession, foreign_key = session_id))] #[diesel(table_name = super::schema::chat_messages)] pub struct ChatRsMessage { @@ -77,6 +74,9 @@ pub struct ChatRsMessage { #[derive(Debug, Default, JsonSchema, Serialize, Deserialize, AsJsonb)] pub struct ChatRsMessageMeta { + /// User messages: metadata associated with the user message + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option<UserMeta>, /// Assistant messages: metadata associated with the assistant message #[serde(skip_serializing_if = "Option::is_none")] pub assistant: Option<AssistantMeta>, @@ -85,14 +85,27 @@ pub struct ChatRsMessageMeta { pub tool_call: Option<ChatRsExecutedToolCall>, } impl ChatRsMessageMeta { - pub fn new_assistant(assistant: AssistantMeta) -> Self { + pub fn new_assistant(assistant_meta: AssistantMeta) -> Self { + Self { + assistant: Some(assistant_meta), + ..Default::default() + } + } + pub fn
new_user(user_meta: UserMeta) -> Self { Self { - assistant: Some(assistant), - tool_call: None, + user: Some(user_meta), + ..Default::default() } } } +#[derive(Debug, Default, JsonSchema, Serialize, Deserialize)] +pub struct UserMeta { + /// The IDs of the files attached to this message + #[serde(skip_serializing_if = "Option::is_none")] + pub files: Option<Vec<Uuid>>, +} + #[derive(Debug, Default, JsonSchema, Serialize, Deserialize)] pub struct AssistantMeta { /// The ID of the LLM provider used to generate this message diff --git a/server/src/db/models/file.rs b/server/src/db/models/file.rs new file mode 100644 index 0000000..72613be --- /dev/null +++ b/server/src/db/models/file.rs @@ -0,0 +1,68 @@ +use chrono::{DateTime, Utc}; +use diesel::prelude::*; +use schemars::JsonSchema; +use uuid::Uuid; + +use crate::{db::models::ChatRsUser, provider::LlmError}; + +#[derive(Identifiable, Associations, Queryable, Selectable, JsonSchema, serde::Serialize)] +#[diesel(belongs_to(ChatRsUser, foreign_key = user_id))] +#[diesel(table_name = super::schema::files)] +pub struct ChatRsFile { + pub id: Uuid, + #[serde(skip)] + pub user_id: Uuid, + #[serde(skip)] + pub session_id: Option<Uuid>, + pub path: String, + #[schemars(with = "ChatRsFileType")] + pub file_type: String, + pub content_type: String, + pub size: i32, + pub created_at: DateTime<Utc>, + #[serde(skip)] + pub updated_at: DateTime<Utc>, +} + +#[derive(Insertable)] +#[diesel(table_name = super::schema::files)] +pub struct NewChatRsFile<'r> { + pub user_id: &'r Uuid, + pub session_id: Option<&'r Uuid>, + pub path: &'r str, + pub file_type: &'r str, + pub content_type: &'r str, + pub size: i32, +} + +/// File modality +#[derive(Debug, PartialEq, Eq, Hash, JsonSchema)] +#[serde(rename_all = "lowercase")] +pub enum ChatRsFileType { + Text, + Image, + Pdf, +} + +impl TryFrom<&str> for ChatRsFileType { + type Error = LlmError; + + fn try_from(file_type: &str) -> Result<Self, Self::Error> { + match file_type { + "text" => Ok(ChatRsFileType::Text), + "image" =>
Ok(ChatRsFileType::Image), + "pdf" => Ok(ChatRsFileType::Pdf), + _ => Err(LlmError::InvalidFileType(file_type.to_owned())), + } + } +} + +impl From<ChatRsFileType> for &'static str { + fn from(file_type: ChatRsFileType) -> Self { + match file_type { + ChatRsFileType::Text => "text", + ChatRsFileType::Image => "image", + ChatRsFileType::Pdf => "pdf", + } + } +} diff --git a/server/src/db/models/provider.rs b/server/src/db/models/provider.rs index 6ce9af7..4dda182 100644 --- a/server/src/db/models/provider.rs +++ b/server/src/db/models/provider.rs @@ -1,8 +1,5 @@ use chrono::{DateTime, Utc}; -use diesel::{ - prelude::{AsChangeset, Associations, Identifiable, Insertable, Queryable}, - Selectable, -}; +use diesel::prelude::*; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -17,7 +14,7 @@ pub struct ChatRsProvider { pub name: String, #[schemars(with = "ChatRsProviderType")] pub provider_type: String, - #[serde(skip_serializing)] + #[serde(skip)] pub user_id: Uuid, pub default_model: String, pub base_url: Option<String>, diff --git a/server/src/db/models/secret.rs b/server/src/db/models/secret.rs index 5e34929..5bf48d8 100644 --- a/server/src/db/models/secret.rs +++ b/server/src/db/models/secret.rs @@ -1,8 +1,5 @@ use chrono::{DateTime, Utc}; -use diesel::{ - prelude::{AsChangeset, Associations, Identifiable, Insertable, Queryable}, - Selectable, -}; +use diesel::prelude::*; use schemars::JsonSchema; use uuid::Uuid; diff --git a/server/src/db/models/tool.rs b/server/src/db/models/tool.rs index a328ea5..a2bbc9f 100644 --- a/server/src/db/models/tool.rs +++ b/server/src/db/models/tool.rs @@ -1,10 +1,7 @@ use std::collections::HashMap; use chrono::{DateTime, Utc}; -use diesel::{ - prelude::{Associations, Identifiable, Insertable, Queryable}, - Selectable, -}; +use diesel::prelude::*; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use uuid::Uuid; diff --git a/server/src/db/models/user.rs b/server/src/db/models/user.rs index e9c5c02..7add663 100644
--- a/server/src/db/models/user.rs +++ b/server/src/db/models/user.rs @@ -1,8 +1,5 @@ use chrono::{DateTime, Utc}; -use diesel::{ - prelude::{AsChangeset, Identifiable, Insertable, Queryable}, - Selectable, -}; +use diesel::prelude::*; use schemars::JsonSchema; use uuid::Uuid; diff --git a/server/src/db/schema.rs b/server/src/db/schema.rs index 7ce8c07..3c39fd1 100644 --- a/server/src/db/schema.rs +++ b/server/src/db/schema.rs @@ -59,6 +59,20 @@ diesel::table! { } } +diesel::table! { + files (id) { + id -> Uuid, + user_id -> Uuid, + session_id -> Nullable<Uuid>, + path -> Text, + file_type -> Text, + content_type -> Text, + size -> Int4, + created_at -> Timestamptz, + updated_at -> Timestamptz, + } +} + diesel::table! { providers (id) { id -> Int4, @@ -124,6 +138,8 @@ diesel::joinable!(app_api_keys -> users (user_id)); diesel::joinable!(chat_messages -> chat_sessions (session_id)); diesel::joinable!(chat_sessions -> users (user_id)); diesel::joinable!(external_api_tools -> users (user_id)); +diesel::joinable!(files -> chat_sessions (session_id)); +diesel::joinable!(files -> users (user_id)); diesel::joinable!(providers -> secrets (api_key_id)); diesel::joinable!(providers -> users (user_id)); diesel::joinable!(secrets -> users (user_id)); @@ -135,6 +151,7 @@ diesel::allow_tables_to_appear_in_same_query!( chat_messages, chat_sessions, external_api_tools, + files, providers, secrets, system_tools, diff --git a/server/src/db/services.rs b/server/src/db/services.rs index 236bec5..1e5ef73 100644 --- a/server/src/db/services.rs +++ b/server/src/db/services.rs @@ -1,5 +1,6 @@ mod api_key; mod chat; +mod file; mod provider; mod secret; mod tool; @@ -7,6 +8,7 @@ mod user; pub use api_key::ApiKeyDbService; pub use chat::ChatDbService; +pub use file::FileDbService; pub use provider::ProviderDbService; pub use secret::SecretDbService; pub use tool::ToolDbService; diff --git a/server/src/db/services/file.rs b/server/src/db/services/file.rs new file mode 100644 index 0000000..8c2928b
--- /dev/null +++ b/server/src/db/services/file.rs @@ -0,0 +1,73 @@ +use diesel::prelude::*; +use diesel_async::RunQueryDsl; +use uuid::Uuid; + +use crate::db::{ + models::{ChatRsFile, NewChatRsFile}, + schema::files, + DbConnection, +}; + +pub struct FileDbService<'a> { + pub db: &'a mut DbConnection, +} + +impl<'a> FileDbService<'a> { + pub fn new(db: &'a mut DbConnection) -> Self { + Self { db } + } + + pub async fn create_session_file( + &mut self, + file: NewChatRsFile<'_>, + ) -> QueryResult<ChatRsFile> { + diesel::insert_into(files::table) + .values(file) + .returning(ChatRsFile::as_returning()) + .get_result(self.db) + .await + } + + pub async fn find_session_file( + &mut self, + user_id: &Uuid, + session_id: &Uuid, + file_id: &Uuid, + ) -> QueryResult<ChatRsFile> { + files::table + .filter(files::user_id.eq(user_id)) + .filter(files::session_id.eq(session_id)) + .filter(files::id.eq(file_id)) + .select(ChatRsFile::as_select()) + .first(self.db) + .await + } + + pub async fn list_session_files( + &mut self, + user_id: &Uuid, + session_id: &Uuid, + ) -> QueryResult<Vec<ChatRsFile>> { + files::table + .filter(files::user_id.eq(user_id)) + .filter(files::session_id.eq(session_id)) + .select(ChatRsFile::as_select()) + .load(self.db) + .await + } + + pub async fn delete_session_file( + &mut self, + user_id: &Uuid, + session_id: &Uuid, + file_id: &Uuid, + ) -> QueryResult<Uuid> { + diesel::delete(files::table) + .filter(files::user_id.eq(user_id)) + .filter(files::session_id.eq(session_id)) + .filter(files::id.eq(file_id)) + .returning(files::id) + .get_result(self.db) + .await + } +} diff --git a/server/src/errors.rs b/server/src/errors.rs index 7adefed..41ca5f0 100644 --- a/server/src/errors.rs +++ b/server/src/errors.rs @@ -24,6 +24,8 @@ pub enum ApiError { Chat(#[from] LlmError), #[error(transparent)] Tool(#[from] ToolError), + #[error(transparent)] + Io(#[from] std::io::Error), } #[derive(Debug, JsonSchema, serde::Serialize)] @@ -76,6 +78,10 @@ impl<'r, 'o: 'r> response::Responder<'r, 'o> for ApiError {
ApiErrorResponse::BadRequest(Json(Message::new(&format!("Tool error: {}", error)))) .respond_to(req) } + ApiError::Io(error) => { + ApiErrorResponse::BadRequest(Json(Message::new(&format!("IO error: {}", error)))) + .respond_to(req) + } _ => ApiErrorResponse::Server(Json(Message::new("Server error!"))).respond_to(req), } } diff --git a/server/src/lib.rs b/server/src/lib.rs index 828ce99..74277ef 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -4,8 +4,8 @@ pub mod config; pub mod db; pub mod errors; pub mod provider; -pub mod provider_models; pub mod redis; +pub mod storage; pub mod stream; pub mod tools; pub mod utils; @@ -20,6 +20,7 @@ use crate::{ db::setup_db, errors::get_catchers, redis::setup_redis, + storage::setup_storage, utils::setup_encryption, web::setup_static_files, }; @@ -32,6 +33,7 @@ pub fn build_rocket() -> rocket::Rocket<rocket::Build> { .attach(setup_redis()) .attach(setup_encryption()) .attach(setup_auth("/api/auth")) + .attach(setup_storage()) .attach(setup_static_files()) .manage(reqwest::Client::new()) .register("/", get_catchers()) @@ -47,6 +49,7 @@ pub fn build_rocket() -> rocket::Rocket<rocket::Build> { "/session" => api::session_routes(&openapi_settings), "/chat" => api::chat_routes(&openapi_settings), "/tool" => api::tool_routes(&openapi_settings), + "/storage" => api::storage_routes(&openapi_settings), "/secret" => api::secret_routes(&openapi_settings), "/api_key" => api::api_key_routes(&openapi_settings), }; diff --git a/server/src/provider.rs b/server/src/provider.rs index bc4fb32..9e3c520 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -1,153 +1,27 @@ -//!
LLM providers module -pub mod anthropic; -pub mod lorem; -pub mod ollama; -pub mod openai; -mod utils; - -use std::pin::Pin; - -use dyn_clone::DynClone; -use rocket::{async_trait, futures::Stream}; -use schemars::JsonSchema; use uuid::Uuid; +mod core; +pub use core::*; +pub mod models; +pub mod providers; +mod utils; + use crate::{ - db::models::{ChatRsMessage, ChatRsProviderType, ChatRsToolCall}, - provider::{ - anthropic::AnthropicProvider, lorem::LoremProvider, ollama::OllamaProvider, - openai::OpenAIProvider, + db::{ + models::{ChatRsMessage, ChatRsMessageRole, ChatRsProviderType}, + services::FileDbService, + DbConnection, }, - provider_models::LlmModel, + errors::ApiError, + provider::{models::LlmModel, providers::*}, + storage::LocalStorage, }; pub const DEFAULT_MAX_TOKENS: u32 = 2000; pub const DEFAULT_TEMPERATURE: f32 = 0.7; -/// LLM provider-related errors -#[derive(Debug, thiserror::Error)] -pub enum LlmError { - #[error("Missing API key")] - MissingApiKey, - #[error("Provider error: {0}")] - ProviderError(String), - #[error("models.dev error: {0}")] - ModelsDevError(String), - #[error("No chat response")] - NoResponse, - #[error("Unsupported provider")] - UnsupportedProvider, - #[error("Already streaming a response for this session")] - AlreadyStreaming, - #[error("No stream found, or the stream was cancelled")] - StreamNotFound, - #[error("Missing event in stream")] - NoStreamEvent, - #[error("Client disconnected")] - ClientDisconnected, - #[error("Encryption error")] - EncryptionError, - #[error("Decryption error")] - DecryptionError, - #[error("Redis error: {0}")] - Redis(#[from] fred::error::Error), -} - -/// LLM errors during streaming -#[derive(Debug, thiserror::Error)] -pub enum LlmStreamError { - #[error("Provider error: {0}")] - ProviderError(String), - #[error("Failed to parse event: {0}")] - Parsing(#[from] serde_json::Error), - #[error("Failed to decode response: {0}")] - Decoding(#[from] tokio_util::codec::LinesCodecError), - 
#[error("Timeout waiting for provider response")] - StreamTimeout, - #[error("Stream was cancelled")] - StreamCancelled, - #[error("Redis error: {0}")] - Redis(#[from] fred::error::Error), -} - -/// Shared stream response type for LLM providers -pub type LlmStream = Pin<Box<dyn Stream<Item = LlmStreamChunkResult> + Send>>; - -/// Shared stream chunk result type for LLM providers -pub type LlmStreamChunkResult = Result<LlmStreamChunk, LlmStreamError>; - -/// A streaming chunk of data from the LLM provider -pub enum LlmStreamChunk { - Text(String), - ToolCalls(Vec<ChatRsToolCall>), - PendingToolCall(LlmPendingToolCall), - Usage(LlmUsage), -} - -#[derive(Debug, Clone, serde::Serialize)] -pub struct LlmPendingToolCall { - pub index: usize, - pub tool_name: String, -} - -/// Usage stats from the LLM provider -#[derive(Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] -pub struct LlmUsage { - pub input_tokens: Option<u32>, - pub output_tokens: Option<u32>, - /// Only included by OpenRouter - #[serde(skip_serializing_if = "Option::is_none")] - pub cost: Option<f64>, -} - -/// Shared configuration for LLM provider requests -#[derive(Clone, Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] -pub struct LlmProviderOptions { - pub model: String, - pub temperature: Option<f32>, - pub max_tokens: Option<u32>, -} - -/// Generic tool that can be passed to LLM providers -#[derive(Debug)] -pub struct LlmTool { - pub name: String, - pub description: String, - pub input_schema: serde_json::Value, - /// ID of the RsChat tool that this is derived from - pub tool_id: Uuid, - /// The type of tool this is derived from (internal, external API, etc.)
- pub tool_type: LlmToolType, -} - -#[derive(Default, Debug, Clone, Copy, JsonSchema, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum LlmToolType { - #[default] - System, - ExternalApi, -} - -/// Unified API for LLM providers -#[async_trait] -pub trait LlmApiProvider: Send + Sync + DynClone { - /// Stream a chat response from the provider - async fn chat_stream( - &self, - messages: Vec<ChatRsMessage>, - tools: Option<Vec<LlmTool>>, - options: &LlmProviderOptions, - ) -> Result<LlmStream, LlmError>; - - /// Submit a prompt to the provider (not streamed) - async fn prompt(&self, message: &str, options: &LlmProviderOptions) - -> Result<String, LlmError>; - - /// List available models from the provider - async fn list_models(&self) -> Result<Vec<LlmModel>, LlmError>; -} - /// Build the LLM API to make calls to the provider pub fn build_llm_provider_api( provider_type: &ChatRsProviderType, @@ -175,3 +49,60 @@ pub fn build_llm_provider_api( ChatRsProviderType::Lorem => Ok(Box::new(LoremProvider::new())), } } + +/// Convert database messages to the generic messages to send to the provider implementation +pub async fn build_llm_messages( + messages: Vec<ChatRsMessage>, + user_id: &Uuid, + session_id: &Uuid, + db: &mut DbConnection, + storage: &LocalStorage, +) -> Result<Vec<LlmMessage>, ApiError> { + let mut llm_messages = Vec::with_capacity(messages.len()); + + for message in messages { + match message.role { + ChatRsMessageRole::User => { + let mut files: Option<Vec<LlmFileInput>> = None; + if let Some(file_ids) = message.meta.user.and_then(|u| u.files) { + let mut file_db_service = FileDbService::new(db); + for file_id in file_ids { + let file = file_db_service + .find_session_file(user_id, session_id, &file_id) + .await?; + let (file_type, content) = + file.read_to_string(Some(session_id), storage).await?; + files.get_or_insert_default().push(LlmFileInput { + name: file.path, + content_type: file.content_type, + file_type, + content, + }); + } + } + llm_messages.push(LlmMessage::User(LlmUserMessage { + text: message.content, + files, + })) + } +
ChatRsMessageRole::Assistant => { + llm_messages.push(LlmMessage::Assistant(LlmAssistantMessage { + text: message.content, + tool_calls: message.meta.assistant.and_then(|a| a.tool_calls), + })) + } + ChatRsMessageRole::System => llm_messages.push(LlmMessage::System(message.content)), + ChatRsMessageRole::Tool => { + if let Some(tool_call) = message.meta.tool_call { + llm_messages.push(LlmMessage::Tool(LlmToolResult { + tool_call_id: tool_call.id, + tool_name: tool_call.tool_name, + content: message.content, + })) + } + } + } + } + + Ok(llm_messages) +} diff --git a/server/src/provider/anthropic/request.rs b/server/src/provider/anthropic/request.rs deleted file mode 100644 index 43af62b..0000000 --- a/server/src/provider/anthropic/request.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::collections::HashMap; - -use serde::Serialize; - -use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole}, - provider::LlmTool, -}; - -pub fn build_anthropic_messages<'a>( - messages: &'a [ChatRsMessage], -) -> (Vec<AnthropicMessage<'a>>, Option<&'a str>) { - let system_prompt = messages - .iter() - .rfind(|message| message.role == ChatRsMessageRole::System) - .map(|message| message.content.as_str()); - - let anthropic_messages: Vec<AnthropicMessage> = messages - .iter() - .filter_map(|message| { - let role = match message.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Tool => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => return None, - }; - - let mut content_blocks = Vec::new(); - - // Handle tool result messages - if message.role == ChatRsMessageRole::Tool { - if let Some(executed_call) = &message.meta.tool_call { - content_blocks.push(AnthropicContentBlock::ToolResult { - tool_use_id: &executed_call.id, - content: &message.content, - }); - } - } else { - // Handle regular text content - if !message.content.is_empty() { - content_blocks.push(AnthropicContentBlock::Text { - text: &message.content, - }); - } - // Handle tool calls in assistant messages - if let
Some(tool_calls) = message - .meta - .assistant - .as_ref() - .and_then(|a| a.tool_calls.as_ref()) - { - for tool_call in tool_calls { - content_blocks.push(AnthropicContentBlock::ToolUse { - id: &tool_call.id, - name: &tool_call.tool_name, - input: &tool_call.parameters, - }); - } - } - } - - if content_blocks.is_empty() { - return None; - } - - Some(AnthropicMessage { - role, - content: content_blocks, - }) - }) - .collect(); - - (anthropic_messages, system_prompt) -} - -pub fn build_anthropic_tools<'a>(tools: &'a [LlmTool]) -> Vec<AnthropicTool<'a>> { - tools - .iter() - .map(|tool| AnthropicTool { - name: &tool.name, - description: &tool.description, - input_schema: &tool.input_schema, - }) - .collect() -} - -/// Anthropic API request message -#[derive(Debug, Serialize)] -pub struct AnthropicMessage<'a> { - pub role: &'a str, - pub content: Vec<AnthropicContentBlock<'a>>, -} - -/// Anthropic API request body -#[derive(Debug, Serialize)] -pub struct AnthropicRequest<'a> { - pub model: &'a str, - pub messages: Vec<AnthropicMessage<'a>>, - pub max_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option<f32>, - #[serde(skip_serializing_if = "Option::is_none")] - pub system: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option<bool>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option<Vec<AnthropicTool<'a>>>, -} - -/// Anthropic tool definition -#[derive(Debug, Serialize)] -pub struct AnthropicTool<'a> { - name: &'a str, - description: &'a str, - input_schema: &'a serde_json::Value, -} - -/// Anthropic content block for messages -#[derive(Debug, Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum AnthropicContentBlock<'a> { - Text { - text: &'a str, - }, - ToolUse { - id: &'a str, - name: &'a str, - input: &'a HashMap<String, serde_json::Value>, - }, - ToolResult { - tool_use_id: &'a str, - content: &'a str, - }, -} diff --git a/server/src/provider/core.rs b/server/src/provider/core.rs new file mode 100644 index 0000000..83061cd --- /dev/null +++ b/server/src/provider/core.rs
@@ -0,0 +1,171 @@ +//! LLM providers - core structs and types + +use std::pin::Pin; + +use dyn_clone::DynClone; +use rocket::{async_trait, futures::Stream}; +use schemars::JsonSchema; +use uuid::Uuid; + +use crate::{ + db::models::{ChatRsFileType, ChatRsToolCall}, + provider::models::LlmModel, +}; + +/// Unified API for LLM providers +#[async_trait] +pub trait LlmApiProvider: Send + Sync + DynClone { + /// Stream a chat response from the provider + async fn chat_stream( + &self, + messages: Vec<LlmMessage>, + tools: Option<Vec<LlmTool>>, + options: &LlmProviderOptions, + ) -> Result<LlmStream, LlmError>; + + /// Submit a prompt to the provider (not streamed) + async fn prompt(&self, message: &str, options: &LlmProviderOptions) + -> Result<String, LlmError>; + + /// List available models from the provider + async fn list_models(&self) -> Result<Vec<LlmModel>, LlmError>; +} + +/// Stream response type for LLM providers +pub type LlmStream = Pin<Box<dyn Stream<Item = LlmStreamChunkResult> + Send>>; + +/// Stream chunk result type for LLM providers +pub type LlmStreamChunkResult = Result<LlmStreamChunk, LlmStreamError>; + +/// A streaming chunk of data from the LLM provider +pub enum LlmStreamChunk { + Text(String), + ToolCalls(Vec<ChatRsToolCall>), + PendingToolCall(LlmPendingToolCall), + Usage(LlmUsage), +} + +/// LLM provider-related errors +#[derive(Debug, thiserror::Error)] +pub enum LlmError { + #[error("Missing API key")] + MissingApiKey, + #[error("Provider error: {0}")] + ProviderError(String), + #[error("models.dev error: {0}")] + ModelsDevError(String), + #[error("No chat response")] + NoResponse, + #[error("Unsupported provider")] + UnsupportedProvider, + #[error("Already streaming a response for this session")] + AlreadyStreaming, + #[error("No stream found, or the stream was cancelled")] + StreamNotFound, + #[error("Missing event in stream")] + NoStreamEvent, + #[error("Client disconnected")] + ClientDisconnected, + #[error("Encryption error")] + EncryptionError, + #[error("Decryption error")] + DecryptionError, + #[error("Redis error: {0}")] + Redis(#[from] fred::error::Error), + #[error("File error: {0}")] + Io(#[from]
std::io::Error), + #[error("Invalid file type: {0}")] + InvalidFileType(String), +} + +/// LLM errors that can occur during streaming +#[derive(Debug, thiserror::Error)] +pub enum LlmStreamError { + #[error("Provider error: {0}")] + ProviderError(String), + #[error("Failed to parse event: {0}")] + Parsing(#[from] serde_json::Error), + #[error("Failed to decode response: {0}")] + Decoding(#[from] tokio_util::codec::LinesCodecError), + #[error("Timeout waiting for provider response")] + StreamTimeout, + #[error("Stream was cancelled")] + StreamCancelled, + #[error("Redis error: {0}")] + Redis(#[from] fred::error::Error), +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct LlmPendingToolCall { + pub index: usize, + pub tool_name: String, +} + +/// Usage stats from the LLM provider +#[derive(Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] +pub struct LlmUsage { + pub input_tokens: Option<u32>, + pub output_tokens: Option<u32>, + /// Only included by OpenRouter + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option<f64>, +} + +/// Configuration for LLM provider requests +#[derive(Clone, Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] +pub struct LlmProviderOptions { + pub model: String, + pub temperature: Option<f32>, + pub max_tokens: Option<u32>, +} + +/// Generic message type to send to LLM providers +pub enum LlmMessage { + User(LlmUserMessage), + Assistant(LlmAssistantMessage), + System(String), + Tool(LlmToolResult), +} + +pub struct LlmUserMessage { + pub text: String, + pub files: Option<Vec<LlmFileInput>>, +} + +pub struct LlmFileInput { + pub name: String, + pub file_type: ChatRsFileType, + pub content_type: String, + pub content: String, +} + +pub struct LlmAssistantMessage { + pub text: String, + pub tool_calls: Option<Vec<ChatRsToolCall>>, +} + +/// Generic tool that can be passed to LLM providers +#[derive(Debug)] +pub struct LlmTool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, + /// ID of the RsChat tool that this
is derived from + pub tool_id: Uuid, + /// The type of tool this is derived from (internal, external API, etc.) + pub tool_type: LlmToolType, +} + +#[derive(Default, Debug, Clone, Copy, JsonSchema, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum LlmToolType { + #[default] + System, + ExternalApi, +} + +pub struct LlmToolResult { + pub tool_call_id: String, + pub tool_name: String, + pub content: String, +} diff --git a/server/src/provider_models.rs b/server/src/provider/models.rs similarity index 99% rename from server/src/provider_models.rs rename to server/src/provider/models.rs index 6ed5bc8..f527ac1 100644 --- a/server/src/provider_models.rs +++ b/server/src/provider/models.rs @@ -1,3 +1,5 @@ +//! LLM model structs and utils + use std::collections::HashMap; use enum_iterator::{all, Sequence}; diff --git a/server/src/provider/openai/request.rs b/server/src/provider/openai/request.rs deleted file mode 100644 index 9feafcf..0000000 --- a/server/src/provider/openai/request.rs +++ /dev/null @@ -1,131 +0,0 @@ -use serde::Serialize; - -use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole}, - provider::LlmTool, -}; - -pub fn build_openai_messages<'a>(messages: &'a [ChatRsMessage]) -> Vec<OpenAIMessage<'a>> { - messages - .iter() - .map(|message| { - let role = match message.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => "system", - ChatRsMessageRole::Tool => "tool", - }; - let openai_message = OpenAIMessage { - role, - content: Some(&message.content), - tool_call_id: message.meta.tool_call.as_ref().map(|tc| tc.id.as_str()), - tool_calls: message - .meta - .assistant - .as_ref() - .and_then(|meta| meta.tool_calls.as_ref()) - .map(|tc| { - tc.iter() - .map(|tc| OpenAIToolCall { - id: &tc.id, - tool_type: "function", - function: OpenAIToolCallFunction { - name: &tc.tool_name, - arguments: serde_json::to_string(&tc.parameters) - .unwrap_or_default(), - }, - }) -
.collect() - }), - }; - - openai_message - }) - .collect() -} - -pub fn build_openai_tools<'a>(tools: &'a [LlmTool]) -> Vec<OpenAITool<'a>> { - tools - .iter() - .map(|tool| OpenAITool { - tool_type: "function", - function: OpenAIToolFunction { - name: &tool.name, - description: &tool.description, - parameters: &tool.input_schema, - strict: true, - }, - }) - .collect() -} - -/// OpenAI API request body -#[derive(Debug, Default, Serialize)] -pub struct OpenAIRequest<'a> { - pub model: &'a str, - pub messages: Vec<OpenAIMessage<'a>>, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option<u32>, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_completion_tokens: Option<u32>, - pub temperature: Option<f32>, - #[serde(skip_serializing_if = "Option::is_none")] - pub store: Option<bool>, - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option<bool>, - #[serde(skip_serializing_if = "Option::is_none")] - pub stream_options: Option<OpenAIStreamOptions>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option<Vec<OpenAITool<'a>>>, -} - -/// OpenAI API request stream options -#[derive(Debug, Serialize)] -pub struct OpenAIStreamOptions { - pub include_usage: bool, -} - -/// OpenAI API request message -#[derive(Debug, Default, Serialize)] -pub struct OpenAIMessage<'a> { - pub role: &'a str, - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option<Vec<OpenAIToolCall<'a>>>, -} - -/// OpenAI tool definition -#[derive(Debug, Serialize)] -pub struct OpenAITool<'a> { - #[serde(rename = "type")] - tool_type: &'a str, - function: OpenAIToolFunction<'a>, -} - -/// OpenAI tool function definition -#[derive(Debug, Serialize)] -pub struct OpenAIToolFunction<'a> { - name: &'a str, - description: &'a str, - parameters: &'a serde_json::Value, - strict: bool, -} - -/// OpenAI tool call in messages -#[derive(Debug, Serialize)] -pub struct OpenAIToolCall<'a> { -
id: &'a str, - #[serde(rename = "type")] - tool_type: &'a str, - function: OpenAIToolCallFunction<'a>, -} - -/// OpenAI tool call function in messages -#[derive(Debug, Serialize)] -pub struct OpenAIToolCallFunction<'a> { - name: &'a str, - arguments: String, -} diff --git a/server/src/provider/providers.rs b/server/src/provider/providers.rs new file mode 100644 index 0000000..d558d49 --- /dev/null +++ b/server/src/provider/providers.rs @@ -0,0 +1,11 @@ +//! LLM provider implementations + +mod anthropic; +mod lorem; +mod ollama; +mod openai; + +pub use anthropic::*; +pub use lorem::*; +pub use ollama::*; +pub use openai::*; diff --git a/server/src/provider/anthropic.rs b/server/src/provider/providers/anthropic.rs similarity index 86% rename from server/src/provider/anthropic.rs rename to server/src/provider/providers/anthropic.rs index bcea95a..5be59da 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/providers/anthropic.rs @@ -5,22 +5,8 @@ mod response; use rocket::{async_stream, async_trait, futures::StreamExt}; -use crate::{ - db::models::ChatRsMessage, - provider::{ - utils::get_sse_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmTool, - LlmUsage, DEFAULT_MAX_TOKENS, - }, - provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, -}; - -use { - request::{ - build_anthropic_messages, build_anthropic_tools, AnthropicContentBlock, AnthropicMessage, - AnthropicRequest, - }, - response::{parse_anthropic_event, AnthropicResponse, AnthropicResponseContentBlock}, -}; +use crate::provider::*; +use {request::*, response::*}; const MESSAGES_API_URL: &str = "https://api.anthropic.com/v1/messages"; const API_VERSION: &str = "2023-06-01"; @@ -51,7 +37,7 @@ impl AnthropicProvider { impl LlmApiProvider for AnthropicProvider { async fn chat_stream( &self, - messages: Vec<ChatRsMessage>, + messages: Vec<LlmMessage>, tools: Option<Vec<LlmTool>>, options: &LlmProviderOptions, ) -> Result<LlmStream, LlmError> { @@ -87,7 +73,7 @@ impl LlmApiProvider for AnthropicProvider { } let stream
= async_stream::stream! { - let mut sse_event_stream = get_sse_events(response); + let mut sse_event_stream = utils::get_sse_events(response); let mut tool_calls = Vec::new(); while let Some(event_result) = sse_event_stream.next().await { match event_result { @@ -163,9 +149,9 @@ } async fn list_models(&self) -> Result<Vec<LlmModel>, LlmError> { - let models_service = ModelsDevService::new(&self.redis, &self.client); + let models_service = models::ModelsDevService::new(&self.redis, &self.client); let models = models_service - .list_models(ModelsDevServiceProvider::Anthropic) + .list_models(models::ModelsDevServiceProvider::Anthropic) .await?; Ok(models) diff --git a/server/src/provider/providers/anthropic/request.rs b/server/src/provider/providers/anthropic/request.rs new file mode 100644 index 0000000..9659861 --- /dev/null +++ b/server/src/provider/providers/anthropic/request.rs @@ -0,0 +1,170 @@ +use std::collections::HashMap; + +use serde::Serialize; + +use crate::{ + db::models::ChatRsFileType, + provider::{LlmMessage, LlmTool}, +}; + +pub fn build_anthropic_messages<'a>( + messages: &'a [LlmMessage], +) -> (Vec<AnthropicMessage<'a>>, Option<&'a str>) { + let system_prompt = messages.iter().rev().find_map(|message| { + let LlmMessage::System(msg) = message else { + return None; + }; + Some(msg.as_str()) + }); + + let anthropic_messages: Vec<AnthropicMessage> = messages + .iter() + .filter_map(|message| { + let mut content_blocks = Vec::new(); + match message { + LlmMessage::User(user_message) => { + if !user_message.text.is_empty() { + content_blocks.push(AnthropicContentBlock::Text { + text: &user_message.text, + }); + } + if let Some(ref files) = user_message.files { + content_blocks.extend(files.iter().map(|file| match file.file_type { + ChatRsFileType::Text => AnthropicContentBlock::Document { + title: &file.name, + source: AnthropicSource::Text { + data: &file.content, + media_type: "text/plain", + }, + }, + ChatRsFileType::Image => AnthropicContentBlock::Image {
source: AnthropicSource::Base64 { + data: &file.content, + media_type: &file.content_type, + }, + }, + ChatRsFileType::Pdf => AnthropicContentBlock::Document { + title: &file.name, + source: AnthropicSource::Base64 { + data: &file.content, + media_type: "application/pdf", + }, + }, + })); + } + Some(AnthropicMessage { + role: "user", + content: content_blocks, + }) + } + LlmMessage::Assistant(assistant_message) => { + if !assistant_message.text.is_empty() { + content_blocks.push(AnthropicContentBlock::Text { + text: &assistant_message.text, + }); + } + if let Some(ref tool_calls) = assistant_message.tool_calls { + content_blocks.extend(tool_calls.iter().map(|tc| { + AnthropicContentBlock::ToolUse { + id: &tc.id, + name: &tc.tool_name, + input: &tc.parameters, + } + })); + } + Some(AnthropicMessage { + role: "assistant", + content: content_blocks, + }) + } + LlmMessage::Tool(result) => { + content_blocks.push(AnthropicContentBlock::ToolResult { + tool_use_id: &result.tool_call_id, + content: &result.content, + }); + Some(AnthropicMessage { + role: "user", + content: content_blocks, + }) + } + _ => None, + } + }) + .collect(); + + (anthropic_messages, system_prompt) +} + +pub fn build_anthropic_tools<'a>(tools: &'a [LlmTool]) -> Vec<AnthropicTool<'a>> { + tools + .iter() + .map(|tool| AnthropicTool { + name: &tool.name, + description: &tool.description, + input_schema: &tool.input_schema, + }) + .collect() +} + +/// Anthropic API request message +#[derive(Debug, Serialize)] +pub struct AnthropicMessage<'a> { + pub role: &'a str, + pub content: Vec<AnthropicContentBlock<'a>>, +} + +/// Anthropic API request body +#[derive(Debug, Serialize)] +pub struct AnthropicRequest<'a> { + pub model: &'a str, + pub messages: Vec<AnthropicMessage<'a>>, + pub max_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option<f32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option<bool>, + #[serde(skip_serializing_if =
"Option::is_none")] + pub tools: Option<Vec<AnthropicTool<'a>>>, +} + +/// Anthropic tool definition +#[derive(Debug, Serialize)] +pub struct AnthropicTool<'a> { + name: &'a str, + description: &'a str, + input_schema: &'a serde_json::Value, +} + +/// Anthropic content block for messages +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum AnthropicContentBlock<'a> { + Text { + text: &'a str, + }, + Image { + source: AnthropicSource<'a>, + }, + Document { + title: &'a str, + source: AnthropicSource<'a>, + }, + ToolUse { + id: &'a str, + name: &'a str, + input: &'a HashMap<String, serde_json::Value>, + }, + ToolResult { + tool_use_id: &'a str, + content: &'a str, + }, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum AnthropicSource<'a> { + Base64 { data: &'a str, media_type: &'a str }, + Text { data: &'a str, media_type: &'a str }, +} diff --git a/server/src/provider/anthropic/response.rs b/server/src/provider/providers/anthropic/response.rs similarity index 100% rename from server/src/provider/anthropic/response.rs rename to server/src/provider/providers/anthropic/response.rs diff --git a/server/src/provider/lorem.rs b/server/src/provider/providers/lorem.rs similarity index 83% rename from server/src/provider/lorem.rs rename to server/src/provider/providers/lorem.rs index 4c13aa3..d93535a 100644 --- a/server/src/provider/lorem.rs +++ b/server/src/provider/providers/lorem.rs @@ -4,34 +4,19 @@ use std::pin::Pin; use std::time::Duration; use rocket::futures::Stream; -use rocket_okapi::JsonSchema; use tokio::time::{interval, Interval}; -use crate::{ - db::models::ChatRsMessage, - provider::{ - LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, - LlmStreamChunkResult, LlmStreamError, LlmTool, - }, - provider_models::LlmModel, -}; +use crate::provider::*; /// A test/dummy provider that streams 'lorem ipsum...'
and emits test errors during the stream #[derive(Debug, Clone)] pub struct LoremProvider { - pub config: LoremConfig, -} - -#[derive(Debug, Clone, JsonSchema)] -pub struct LoremConfig { pub interval: u32, } impl LoremProvider { pub fn new() -> Self { - LoremProvider { - config: LoremConfig { interval: 400 }, - } + LoremProvider { interval: 400 } } } @@ -72,7 +57,7 @@ impl Stream for LoremStream { impl LlmApiProvider for LoremProvider { async fn chat_stream( &self, - _messages: Vec<ChatRsMessage>, + _messages: Vec<LlmMessage>, _tools: Option<Vec<LlmTool>>, _options: &LlmProviderOptions, ) -> Result<LlmStream, LlmError> { @@ -108,7 +93,7 @@ impl LlmApiProvider for LoremProvider { let stream: LlmStream = Box::pin(LoremStream { words: lorem_words, index: 0, - interval: interval(Duration::from_millis(self.config.interval.into())), + interval: interval(Duration::from_millis(self.interval.into())), }); tokio::time::sleep(Duration::from_millis(1000)).await; diff --git a/server/src/provider/ollama.rs b/server/src/provider/providers/ollama.rs similarity index 89% rename from server/src/provider/ollama.rs rename to server/src/provider/providers/ollama.rs index 68ef215..6d2e4a5 100644 --- a/server/src/provider/ollama.rs +++ b/server/src/provider/providers/ollama.rs @@ -5,24 +5,8 @@ mod response; use rocket::{async_stream, async_trait, futures::StreamExt}; -use crate::{ - db::models::ChatRsMessage, - provider::{ - utils::get_json_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, - LlmStreamChunk, LlmTool, LlmUsage, - }, - provider_models::LlmModel, -}; - -use { - request::{ - build_ollama_messages, build_ollama_tools, OllamaChatRequest, OllamaCompletionRequest, - OllamaOptions, - }, - response::{ - parse_ollama_event, OllamaCompletionResponse, OllamaModelsResponse, OllamaToolCall, - }, -}; +use crate::provider::*; +use {request::*, response::*}; const CHAT_API_URL: &str = "/api/chat"; const COMPLETION_API_URL: &str = "/api/generate"; @@ -48,7 +32,7 @@ impl OllamaProvider { impl LlmApiProvider for OllamaProvider { async fn
chat_stream( &self, - messages: Vec<ChatRsMessage>, + messages: Vec<LlmMessage>, tools: Option<Vec<LlmTool>>, options: &LlmProviderOptions, ) -> Result<LlmStream, LlmError> { @@ -85,8 +69,8 @@ impl LlmApiProvider for OllamaProvider { } let stream = async_stream::stream! { - let mut json_stream = get_json_events(response); - let mut tool_calls: Vec<OllamaToolCall> = Vec::new(); + let mut json_stream = utils::get_json_events(response); + let mut tool_calls: Vec<OllamaToolCallResponse> = Vec::new(); while let Some(event) = json_stream.next().await { match event { Ok(event) => { diff --git a/server/src/provider/ollama/request.rs b/server/src/provider/providers/ollama/request.rs similarity index 62% rename from server/src/provider/ollama/request.rs rename to server/src/provider/providers/ollama/request.rs index 520b41f..c685b39 100644 --- a/server/src/provider/ollama/request.rs +++ b/server/src/provider/providers/ollama/request.rs @@ -3,59 +3,63 @@ use serde::Serialize; use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole}, - provider::LlmTool, + db::models::ChatRsFileType, + provider::{LlmMessage, LlmTool}, tools::ToolParameters, }; -/// Convert ChatRsMessages to Ollama messages -pub fn build_ollama_messages(messages: &[ChatRsMessage]) -> Vec<OllamaMessage> { +/// Convert LlmMessages to Ollama messages +pub fn build_ollama_messages(messages: &[LlmMessage]) -> Vec<OllamaMessage> { messages .iter() - .map(|msg| { - let role = match msg.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => "system", - ChatRsMessageRole::Tool => "tool", - }; - - let mut ollama_msg = OllamaMessage { - role, - content: &msg.content, - tool_calls: None, - tool_name: None, - }; - - // Handle tool calls in assistant messages - if msg.role == ChatRsMessageRole::Assistant { - if let Some(msg_tool_calls) = msg - .meta - .assistant - .as_ref() - .and_then(|m| m.tool_calls.as_ref()) - { - let tool_calls = msg_tool_calls + .map(|message| match message { + LlmMessage::User(user_message) => { + let images = user_message.files.as_ref().map(|files| { + files
+ .iter() + .filter_map(|file| match file.file_type { + ChatRsFileType::Image => Some(file.content.as_str()), + _ => None, + }) + .collect::<Vec<_>>() + }); + OllamaMessage { + role: "user", + content: &user_message.text, + images, + ..Default::default() + } + } + LlmMessage::Assistant(assistant_message) => { + let tool_calls = assistant_message.tool_calls.as_ref().map(|tool_calls| { + tool_calls .iter() .map(|tc| OllamaToolCall { - function: OllamaToolFunction { + function: OllamaFunction { name: &tc.tool_name, arguments: &tc.parameters, }, }) - .collect(); - ollama_msg.tool_calls = Some(tool_calls); + .collect() + }); + OllamaMessage { + role: "assistant", + content: &assistant_message.text, + tool_calls, + ..Default::default() } } - - // Handle tool messages (results from tool calls) - if msg.role == ChatRsMessageRole::Tool { - if let Some(ref tool_call) = msg.meta.tool_call { - ollama_msg.tool_name = Some(&tool_call.tool_name); - } - } - - ollama_msg + LlmMessage::System(text) => OllamaMessage { + role: "system", + content: text, + ..Default::default() + }, + LlmMessage::Tool(result) => OllamaMessage { + role: "tool", + content: &result.content, + tool_name: Some(&result.tool_name), + ..Default::default() + }, }) .collect() } @@ -100,11 +104,13 @@ } /// Ollama chat message -#[derive(Debug, Serialize)] +#[derive(Debug, Default, Serialize)] pub struct OllamaMessage<'a> { pub role: &'a str, pub content: &'a str, #[serde(skip_serializing_if = "Option::is_none")] + pub images: Option<Vec<&'a str>>, + #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option<Vec<OllamaToolCall<'a>>>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_name: Option<&'a str>, @@ -113,12 +119,12 @@ /// Ollama tool call in a message #[derive(Debug, Serialize)] pub struct OllamaToolCall<'a> { - pub function: OllamaToolFunction<'a>, + pub function: OllamaFunction<'a>, } /// Ollama tool function #[derive(Debug, Serialize)] -pub struct
OllamaToolFunction<'a> { +pub struct OllamaFunction<'a> { pub name: &'a str, pub arguments: &'a ToolParameters, } diff --git a/server/src/provider/ollama/response.rs b/server/src/provider/providers/ollama/response.rs similarity index 90% rename from server/src/provider/ollama/response.rs rename to server/src/provider/providers/ollama/response.rs index 285b61e..b7c6610 100644 --- a/server/src/provider/ollama/response.rs +++ b/server/src/provider/providers/ollama/response.rs @@ -9,8 +9,8 @@ use crate::{ /// Parse Ollama streaming event into LlmStreamChunks, and track tool calls pub fn parse_ollama_event( - event: OllamaStreamResponse, - tool_calls: &mut Vec, + event: OllamaStreamEvent, + tool_calls: &mut Vec, ) -> Vec> { let mut chunks = Vec::with_capacity(1); // Handle final message with usage stats @@ -42,10 +42,10 @@ pub fn parse_ollama_event( /// Ollama chat response (streaming) #[derive(Debug, Deserialize)] -pub struct OllamaStreamResponse { +pub struct OllamaStreamEvent { pub model: String, pub created_at: String, - pub message: OllamaMessage, + pub message: OllamaMessageResponse, pub done: bool, #[serde(default)] pub done_reason: Option, @@ -88,28 +88,28 @@ pub struct OllamaCompletionResponse { /// Ollama message in response #[derive(Debug, Deserialize)] -pub struct OllamaMessage { +pub struct OllamaMessageResponse { pub role: String, #[serde(default)] pub content: String, #[serde(default)] - pub tool_calls: Vec, + pub tool_calls: Vec, } /// Ollama tool call in response #[derive(Debug, Deserialize)] -pub struct OllamaToolCall { - pub function: OllamaToolFunction, +pub struct OllamaToolCallResponse { + pub function: OllamaFunctionResponse, } /// Ollama tool function in response #[derive(Debug, Deserialize)] -pub struct OllamaToolFunction { +pub struct OllamaFunctionResponse { pub name: String, pub arguments: serde_json::Value, } -impl OllamaToolFunction { +impl OllamaFunctionResponse { /// Convert to ChatRsToolCall if the tool exists in the provided tools pub 
fn convert(self, tools: &[LlmTool]) -> Option { let tool = tools.iter().find(|t| t.name == self.name)?; @@ -140,8 +140,8 @@ impl From<&OllamaCompletionResponse> for Option { } } -impl From<&OllamaStreamResponse> for Option { - fn from(response: &OllamaStreamResponse) -> Self { +impl From<&OllamaStreamEvent> for Option { + fn from(response: &OllamaStreamEvent) -> Self { if response.prompt_eval_count.is_some() || response.eval_count.is_some() { Some(LlmUsage { input_tokens: response.prompt_eval_count, diff --git a/server/src/provider/openai.rs b/server/src/provider/providers/openai.rs similarity index 86% rename from server/src/provider/openai.rs rename to server/src/provider/providers/openai.rs index 5153d45..edf2c06 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/providers/openai.rs @@ -5,22 +5,8 @@ mod response; use rocket::{async_stream, async_trait, futures::StreamExt}; -use crate::{ - db::models::ChatRsMessage, - provider::{ - utils::get_sse_events, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, - LlmStreamChunk, LlmTool, LlmUsage, - }, - provider_models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, -}; - -use { - request::{ - build_openai_messages, build_openai_tools, OpenAIMessage, OpenAIRequest, - OpenAIStreamOptions, - }, - response::{parse_openai_event, OpenAIResponse, OpenAIStreamToolCall}, -}; +use crate::provider::*; +use {request::*, response::*}; const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1"; const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1"; @@ -54,7 +40,7 @@ impl OpenAIProvider { impl LlmApiProvider for OpenAIProvider { async fn chat_stream( &self, - messages: Vec, + messages: Vec, tools: Option>, options: &LlmProviderOptions, ) -> Result { @@ -99,7 +85,7 @@ impl LlmApiProvider for OpenAIProvider { } let stream = async_stream::stream! 
{
-            let mut sse_event_stream = get_sse_events(response);
+            let mut sse_event_stream = utils::get_sse_events(response);
             let mut tool_calls: Vec<OpenAIStreamToolCall> = Vec::new();
             while let Some(event) = sse_event_stream.next().await {
                 match event {
@@ -134,11 +120,12 @@ impl LlmApiProvider for OpenAIProvider {
             model: &options.model,
             messages: vec![OpenAIMessage {
                 role: "user",
-                content: Some(message),
+                content: Some(vec![OpenAIContent::Text { text: message }]),
                 ..Default::default()
             }],
             max_tokens: options.max_tokens,
             temperature: options.temperature,
+            store: (self.base_url == OPENAI_API_BASE_URL).then_some(false),
             ..Default::default()
         };
@@ -182,12 +169,11 @@ impl LlmApiProvider for OpenAIProvider {
     }

     async fn list_models(&self) -> Result<Vec<LlmModel>, LlmError> {
-        let models_service = ModelsDevService::new(&self.redis, &self.client);
-        let models = models_service
+        let models = models::ModelsDevService::new(&self.redis, &self.client)
             .list_models({
                 match self.base_url.as_str() {
-                    OPENROUTER_API_BASE_URL => ModelsDevServiceProvider::OpenRouter,
-                    _ => ModelsDevServiceProvider::OpenAI,
+                    OPENROUTER_API_BASE_URL => models::ModelsDevServiceProvider::OpenRouter,
+                    _ => models::ModelsDevServiceProvider::OpenAI,
                 }
             })
             .await?;
diff --git a/server/src/provider/providers/openai/request.rs b/server/src/provider/providers/openai/request.rs
new file mode 100644
index 0000000..87457d3
--- /dev/null
+++ b/server/src/provider/providers/openai/request.rs
@@ -0,0 +1,188 @@
+use serde::Serialize;
+
+use crate::{
+    db::models::ChatRsFileType,
+    provider::{utils::create_data_uri, LlmMessage, LlmTool},
+};
+
+pub fn build_openai_messages<'a>(messages: &'a [LlmMessage]) -> Vec<OpenAIMessage<'a>> {
+    messages
+        .iter()
+        .map(|message| match message {
+            LlmMessage::User(user_message) => {
+                let mut content = Vec::new();
+                if !user_message.text.is_empty() {
+                    content.push(OpenAIContent::Text {
+                        text: &user_message.text,
+                    });
+                }
+                if let Some(ref files) = user_message.files {
+                    content.extend(files.iter().map(|file| match
file.file_type { + ChatRsFileType::Text => OpenAIContent::Text { + text: &file.content, + }, + ChatRsFileType::Image => OpenAIContent::ImageUrl { + image_url: OpenAIImageUrl { + url: create_data_uri(&file.content_type, &file.content), + }, + }, + ChatRsFileType::Pdf => OpenAIContent::File { + file: OpenAIFile { + file_data: create_data_uri(&file.content_type, &file.content), + filename: &file.name, + }, + }, + })); + } + OpenAIMessage { + role: "user", + content: Some(content), + ..Default::default() + } + } + LlmMessage::Assistant(assistant_message) => { + let tool_calls = assistant_message.tool_calls.as_ref().map(|tc| { + tc.iter() + .map(|tc| OpenAIToolCall { + id: &tc.id, + tool_type: "function", + function: OpenAIToolCallFunction { + name: &tc.tool_name, + arguments: serde_json::to_string(&tc.parameters) + .unwrap_or_default(), + }, + }) + .collect() + }); + OpenAIMessage { + role: "assistant", + content: (!assistant_message.text.is_empty()).then(|| { + vec![OpenAIContent::Text { + text: &assistant_message.text, + }] + }), + tool_calls, + ..Default::default() + } + } + LlmMessage::System(text) => OpenAIMessage { + role: "system", + content: Some(vec![OpenAIContent::Text { text }]), + ..Default::default() + }, + LlmMessage::Tool(tool_result) => OpenAIMessage { + role: "tool", + content: Some(vec![OpenAIContent::Text { + text: &tool_result.content, + }]), + tool_call_id: Some(&tool_result.tool_call_id), + ..Default::default() + }, + }) + .collect() +} + +pub fn build_openai_tools<'a>(tools: &'a [LlmTool]) -> Vec> { + tools + .iter() + .map(|tool| OpenAITool { + tool_type: "function", + function: OpenAIToolFunction { + name: &tool.name, + description: &tool.description, + parameters: &tool.input_schema, + strict: true, + }, + }) + .collect() +} + +/// OpenAI API request body +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest<'a> { + pub model: &'a str, + pub messages: Vec>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>>, +} + +/// OpenAI API request stream options +#[derive(Debug, Serialize)] +pub struct OpenAIStreamOptions { + pub include_usage: bool, +} + +/// OpenAI API request message +#[derive(Debug, Default, Serialize)] +pub struct OpenAIMessage<'a> { + pub role: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>>, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum OpenAIContent<'a> { + Text { text: &'a str }, + ImageUrl { image_url: OpenAIImageUrl }, + File { file: OpenAIFile<'a> }, +} + +#[derive(Debug, Serialize)] +pub struct OpenAIImageUrl { + url: String, +} + +#[derive(Debug, Serialize)] +pub struct OpenAIFile<'a> { + file_data: String, + filename: &'a str, +} + +/// OpenAI tool definition +#[derive(Debug, Serialize)] +pub struct OpenAITool<'a> { + #[serde(rename = "type")] + tool_type: &'a str, + function: OpenAIToolFunction<'a>, +} + +/// OpenAI tool function definition +#[derive(Debug, Serialize)] +pub struct OpenAIToolFunction<'a> { + name: &'a str, + description: &'a str, + strict: bool, + parameters: &'a serde_json::Value, +} + +/// OpenAI tool call in messages +#[derive(Debug, Serialize)] +pub struct OpenAIToolCall<'a> { + id: &'a str, + #[serde(rename = "type")] + tool_type: &'a str, + function: OpenAIToolCallFunction<'a>, +} + +/// OpenAI tool call function in messages +#[derive(Debug, Serialize)] +pub 
struct OpenAIToolCallFunction<'a> { + name: &'a str, + arguments: String, +} diff --git a/server/src/provider/openai/response.rs b/server/src/provider/providers/openai/response.rs similarity index 100% rename from server/src/provider/openai/response.rs rename to server/src/provider/providers/openai/response.rs diff --git a/server/src/provider/utils.rs b/server/src/provider/utils.rs index 1845ab1..713a66c 100644 --- a/server/src/provider/utils.rs +++ b/server/src/provider/utils.rs @@ -1,3 +1,5 @@ +//! Utils for working with LLM requests and responses + use rocket::futures::TryStreamExt; use serde::de::DeserializeOwned; use tokio_stream::{Stream, StreamExt}; @@ -8,6 +10,11 @@ use tokio_util::{ use crate::provider::LlmStreamError; +/// Create a data URI +pub fn create_data_uri(content_type: &str, b64_string: &str) -> String { + format!("data:{};base64,{}", content_type, b64_string) +} + /// Get a stream of deserialized events from a provider SSE stream. pub fn get_sse_events( response: reqwest::Response, diff --git a/server/src/storage.rs b/server/src/storage.rs new file mode 100644 index 0000000..5d90cc1 --- /dev/null +++ b/server/src/storage.rs @@ -0,0 +1,55 @@ +mod data_guard; +mod local; + +use std::path::{Path, PathBuf}; + +use rocket::fairing::AdHoc; +use uuid::Uuid; + +use crate::{ + config::get_app_config, + db::models::{ChatRsFile, ChatRsFileType}, + provider::LlmError, +}; +pub use data_guard::*; +pub use local::*; + +/// Default data directory path. +pub const DEFAULT_DATA_DIR: &str = "/data"; + +/// Setup file reading and writing for the Rocket application. 
+pub fn setup_storage() -> AdHoc {
+    AdHoc::on_ignite("Storage", |rocket| async {
+        let app_config = get_app_config(&rocket);
+        let data_dir = app_config.data_dir.as_deref().unwrap_or(DEFAULT_DATA_DIR);
+        let storage_path = PathBuf::from(data_dir).join("storage");
+        let storage = LocalStorage::new(storage_path);
+
+        rocket.manage(storage)
+    })
+}
+
+impl ChatRsFile {
+    /// Get the file type and contents for LLM input. Uses base64 URLs for image and PDF files.
+    pub async fn read_to_string(
+        &self,
+        session_id: Option<&Uuid>,
+        storage: &LocalStorage,
+    ) -> Result<(ChatRsFileType, String), LlmError> {
+        let file_type = ChatRsFileType::try_from(self.file_type.as_str())?;
+        let content: String = match file_type {
+            ChatRsFileType::Text => {
+                let bytes = storage
+                    .read_file_as_bytes(&self.user_id, session_id, Path::new(&self.path))
+                    .await?;
+                String::from_utf8_lossy(&bytes).into()
+            }
+            ChatRsFileType::Image | ChatRsFileType::Pdf => {
+                storage
+                    .read_file_as_base64(&self.user_id, session_id, Path::new(&self.path))
+                    .await?
+ } + }; + Ok((file_type, content)) + } +} diff --git a/server/src/storage/data_guard.rs b/server/src/storage/data_guard.rs new file mode 100644 index 0000000..25e4cce --- /dev/null +++ b/server/src/storage/data_guard.rs @@ -0,0 +1,95 @@ +use rocket::{ + async_trait, + data::{FromData, Outcome, ToByteUnit}, + http::{ContentType, Status}, + outcome::{try_outcome, IntoOutcome}, + Request, +}; +use rocket_okapi::request::OpenApiFromData; + +use crate::db::models::ChatRsFileType; + +const MAX_FILE_SIZE: usize = 4 * 1024 * 1024; // 4 MB + +/// Data guard for file uploads +pub struct FileData<'r> { + pub data: rocket::data::DataStream<'r>, + pub content_type: &'r ContentType, + pub file_type: ChatRsFileType, + pub content_length: usize, +} + +#[async_trait] +impl<'r> FromData<'r> for FileData<'r> { + type Error = &'static str; + + async fn from_data( + req: &'r Request<'_>, + mut data: rocket::Data<'r>, + ) -> Outcome<'r, Self, Self::Error> { + if data.peek(8).await.is_empty() { + return Outcome::Error((Status::BadRequest, "No data found")); + } + let content_type = try_outcome!(req + .content_type() + .or_error((Status::BadRequest, "No content type found"))); + let content_length: usize = try_outcome!(req + .headers() + .get_one("Content-Length") + .map(|s| s.parse().unwrap_or(0)) + .or_error((Status::LengthRequired, "No content length found"))); + if content_length > MAX_FILE_SIZE { + return Outcome::Error((Status::PayloadTooLarge, "File size exceeds maximum")); + } + + let file_type = { + if content_type.is_jpeg() + || content_type.is_png() + || content_type.is_webp() + || content_type.is_gif() + { + ChatRsFileType::Image + } else if content_type.is_pdf() { + ChatRsFileType::Pdf + } else { + ChatRsFileType::Text + } + }; + + Outcome::Success(FileData { + data: data.open(5.mebibytes()), + file_type, + content_length, + content_type, + }) + } +} + +impl<'r> OpenApiFromData<'r> for FileData<'r> { + fn request_body( + _gen: &mut rocket_okapi::r#gen::OpenApiGenerator, + ) 
-> rocket_okapi::Result { + Ok(rocket_okapi::okapi::openapi3::RequestBody { + description: Some("File data".to_string()), + content: { + let mut content = schemars::Map::new(); + content.insert( + "application/octet-stream".into(), + rocket_okapi::okapi::openapi3::MediaType { + schema: Some(rocket_okapi::okapi::openapi3::SchemaObject { + instance_type: Some(schemars::schema::SingleOrVec::Single(Box::new( + schemars::schema::InstanceType::String, + ))), + format: Some("binary".to_string()), + ..Default::default() + }), + ..Default::default() + }, + ); + content + }, + required: true, + ..Default::default() + }) + } +} diff --git a/server/src/storage/local.rs b/server/src/storage/local.rs new file mode 100644 index 0000000..4c0cd9f --- /dev/null +++ b/server/src/storage/local.rs @@ -0,0 +1,134 @@ +use std::{ + io::Result as IoResult, + path::{Path, PathBuf}, +}; +use tokio::{ + fs::File, + io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, +}; +use uuid::Uuid; + +pub struct LocalStorage { + base_path: PathBuf, +} + +impl LocalStorage { + pub fn new(base_path: PathBuf) -> Self { + LocalStorage { base_path } + } + + pub async fn read_file_as_bytes( + &self, + user_id: &Uuid, + session_id: Option<&Uuid>, + path: &Path, + ) -> IoResult> { + let path = self.get_file_path(user_id, session_id, path)?; + let mut file = File::open(path).await?; + let metadata = file.metadata().await?; + let mut file_reader = BufReader::new(&mut file); + + let mut buffer = Vec::with_capacity(metadata.len() as usize); + file_reader.read_to_end(&mut buffer).await?; + Ok(buffer) + } + + pub async fn read_file_as_base64( + &self, + user_id: &Uuid, + session_id: Option<&Uuid>, + path: &Path, + ) -> IoResult { + let path = self.get_file_path(user_id, session_id, path)?; + tokio::task::spawn_blocking(move || read_base64(&path)).await? 
+ } + + pub async fn create_file( + &self, + user_id: &Uuid, + session_id: Option<&Uuid>, + path: &Path, + mut data: impl AsyncRead + Unpin, + ) -> IoResult { + let file_path = self.get_file_path(user_id, session_id, path)?; + let dir = file_path.parent().expect("Should have a parent directory"); + tokio::fs::create_dir_all(&dir).await?; + + let mut file = File::create_new(&file_path).await?; + let mut file_writer = BufWriter::new(&mut file); + let mut read_buffer = [0; 4096]; + let mut total_bytes_written: usize = 0; + while let Ok(n) = data.read(&mut read_buffer).await { + if n == 0 { + break; + } + file_writer.write_all(&read_buffer[..n]).await?; + total_bytes_written += n; + } + + file_writer.flush().await?; + file.sync_all().await?; + + Ok(total_bytes_written) + } + + pub async fn delete_file>( + &self, + user_id: &Uuid, + session_id: Option<&Uuid>, + path: P, + ) -> IoResult<()> { + let file_path = self.get_file_path(user_id, session_id, path)?; + tokio::fs::remove_file(&file_path).await + } + + fn get_user_directory(&self, user_id: &Uuid, session_id: Option<&Uuid>) -> PathBuf { + let mut dir = self.base_path.join(user_id.to_string()); + match session_id { + Some(session_id) => { + dir.push("sessions"); + dir.push(session_id.to_string()); + dir + } + None => { + dir.push("files"); + dir + } + } + } + + pub fn get_file_path>( + &self, + user_id: &Uuid, + session_id: Option<&Uuid>, + path: P, + ) -> IoResult { + if !path.as_ref().is_relative() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Path must be relative", + )); + } + Ok(self.get_user_directory(user_id, session_id).join(path)) + } +} + +/// Synchronously read a file as a base64 encoded string. +/// (This is synchronous because the `base64` crate is synchronous.) 
+fn read_base64(path: &Path) -> IoResult { + let mut file = std::fs::File::open(path)?; + let file_size = file.metadata()?.len(); + let estimated_size = (file_size + 2) / 3 * 4; + let mut file_reader = std::io::BufReader::new(&mut file); + + let mut result = Vec::with_capacity(estimated_size as usize); + { + let mut encoder = base64::write::EncoderWriter::new( + &mut result, + &base64::engine::general_purpose::STANDARD, + ); + std::io::copy(&mut file_reader, &mut encoder)?; + encoder.finish()?; + } + Ok(String::from_utf8(result).expect("base64 is valid UTF8")) +} diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 0c3f499..7f2b99c 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -304,7 +304,7 @@ impl LlmStreamWriter { mod tests { use super::*; use crate::{ - provider::{lorem::LoremProvider, LlmApiProvider, LlmProviderOptions}, + provider::{providers::LoremProvider, LlmApiProvider, LlmProviderOptions}, redis::{ExclusiveClientManager, ExclusiveClientPool}, stream::{cancel_current_chat_stream, check_chat_stream_exists}, }; diff --git a/server/src/tools.rs b/server/src/tools.rs index 7d81cab..32bd947 100644 --- a/server/src/tools.rs +++ b/server/src/tools.rs @@ -18,31 +18,35 @@ use { /// User configuration of tools when sending a chat message #[derive(Debug, Default, PartialEq, JsonSchema, serde::Serialize, serde::Deserialize)] pub struct SendChatToolInput { + #[serde(skip_serializing_if = "Option::is_none")] pub system: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub external_apis: Option>, } -/// Get all tools from the user's input in LLM generic format -pub async fn get_llm_tools_from_input( - user_id: &Uuid, - input: &SendChatToolInput, - tool_db_service: &mut ToolDbService<'_>, -) -> Result, ApiError> { - let mut llm_tools = Vec::with_capacity(5); - if let Some(ref system_tool_input) = input.system { - let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; 
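An aside on `read_base64` above: standard padded base64 emits 4 output characters for every 3 input bytes, so the `(file_size + 2) / 3 * 4` value passed to `Vec::with_capacity` is in fact the exact output length, not just an estimate. A minimal sketch of the formula (the function name here is illustrative, not from the codebase):

```rust
/// Output length of standard (padded) base64 for `n` input bytes:
/// every 3-byte group becomes 4 characters, and a final partial group
/// is padded with `=` up to 4 characters, i.e. ceil(n / 3) * 4.
fn base64_encoded_len(n: u64) -> u64 {
    (n + 2) / 3 * 4
}

fn main() {
    assert_eq!(base64_encoded_len(0), 0);
    assert_eq!(base64_encoded_len(1), 4); // one byte still fills a 4-char group
    assert_eq!(base64_encoded_len(3), 4);
    assert_eq!(base64_encoded_len(4), 8);
    println!("ok");
}
```

Pre-sizing the buffer this way lets the `EncoderWriter` stream into `result` without reallocating.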
- let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; - llm_tools.extend(system_llm_tools); - } - if let Some(ref external_apis_input) = input.external_apis { - let external_api_tools = tool_db_service - .find_external_api_tools_by_user(&user_id) - .await?; - for tool_input in external_apis_input { - let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; - llm_tools.extend(api_llm_tools); +impl SendChatToolInput { + /// Get all tools from the user's input in LLM generic format + pub async fn get_llm_tools( + &self, + user_id: &Uuid, + tool_db_service: &mut ToolDbService<'_>, + ) -> Result, ApiError> { + let mut llm_tools = Vec::with_capacity(5); + if let Some(ref system_tool_input) = self.system { + let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; + let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; + llm_tools.extend(system_llm_tools); + } + if let Some(ref external_apis_input) = self.external_apis { + let external_api_tools = tool_db_service + .find_external_api_tools_by_user(&user_id) + .await?; + for tool_input in external_apis_input { + let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; + llm_tools.extend(api_llm_tools); + } } - } - Ok(llm_tools) + Ok(llm_tools) + } } diff --git a/server/src/tools/core.rs b/server/src/tools/core.rs index e1c2fcf..8cacf92 100644 --- a/server/src/tools/core.rs +++ b/server/src/tools/core.rs @@ -62,8 +62,10 @@ pub enum ToolError { ToolExecutionError(String), #[error("Tool execution cancelled: {0}")] Cancelled(String), - #[error("IO error: {0}")] + #[error("File error: {0}")] Io(#[from] std::io::Error), + #[error("Database error: {0}")] + Database(#[from] diesel::result::Error), } /// JSON schema for tool input parameters diff --git a/server/src/tools/external_api/web_search.rs b/server/src/tools/external_api/web_search.rs index 3e21167..8688f9c 100644 --- a/server/src/tools/external_api/web_search.rs +++ 
b/server/src/tools/external_api/web_search.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; use crate::{ provider::{LlmTool, LlmToolType}, + tools::utils::get_json_schema, utils::SenderWithLogging, }; @@ -37,12 +38,10 @@ struct ContentInputSchema { url: String, } -static WEB_SEARCH_INPUT_SCHEMA: LazyLock = LazyLock::new(|| { - serde_json::to_value(schema_for!(QueryInputSchema)).expect("Should be valid JSON") -}); -static EXTRACT_INPUT_SCHEMA: LazyLock = LazyLock::new(|| { - serde_json::to_value(schema_for!(ContentInputSchema)).expect("Should be valid JSON") -}); +static WEB_SEARCH_INPUT_SCHEMA: LazyLock = + LazyLock::new(|| get_json_schema::()); +static EXTRACT_INPUT_SCHEMA: LazyLock = + LazyLock::new(|| get_json_schema::()); /// A web search tool that can support multiple providers. pub struct WebSearchTool { diff --git a/server/src/tools/system.rs b/server/src/tools/system.rs index 7b0a088..b197666 100644 --- a/server/src/tools/system.rs +++ b/server/src/tools/system.rs @@ -1,4 +1,5 @@ mod code_runner; +mod files; mod system_info; use diesel_as_jsonb::AsJsonb; @@ -7,16 +8,21 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{db::models::ChatRsSystemTool, provider::LlmTool, utils::SenderWithLogging}; +use crate::{ + config::AppConfig, + db::{models::ChatRsSystemTool, DbConnection}, + provider::LlmTool, + utils::SenderWithLogging, +}; -use super::{ToolError, ToolLog, ToolParameters, ToolResponseFormat, ToolResult}; +use super::*; /// System tool configuration saved in the database #[derive(Debug, Serialize, Deserialize, JsonSchema, AsJsonb)] #[serde(tag = "type", content = "config", rename_all = "snake_case")] pub enum ChatRsSystemToolConfig { CodeRunner(code_runner::CodeRunnerConfig), - Files(()), + Files(files::FilesConfig), SystemInfo, } impl ChatRsSystemToolConfig { @@ -24,7 +30,7 @@ impl ChatRsSystemToolConfig { pub fn validate(&self) -> ToolResult<()> { match self { ChatRsSystemToolConfig::CodeRunner(config) 
=> config.validate(), - ChatRsSystemToolConfig::Files(_) => Ok(()), + ChatRsSystemToolConfig::Files(config) => config.validate(), ChatRsSystemToolConfig::SystemInfo => Ok(()), } } @@ -39,7 +45,8 @@ pub struct SystemToolInput { /// Enable/disable tools to get system information, current date/time, etc. #[serde(default)] info: bool, - // TODO files, etc... + #[serde(default, skip_serializing_if = "Option::is_none")] + files: Option, } impl SystemToolInput { /// Get all the LLM tools given the user's input @@ -67,6 +74,16 @@ impl SystemToolInput { .ok_or(ToolError::ToolNotFound)?; llm_tools.extend(config.get_llm_tools(tool_id, None)); } + if let Some(input) = &self.files { + let (config, tool_id) = system_tools + .iter() + .find_map(|t| match &t.data { + ChatRsSystemToolConfig::Files(config) => Some((config, t.id)), + _ => None, + }) + .ok_or(ToolError::ToolNotFound)?; + llm_tools.extend(config.get_llm_tools(tool_id, Some(input))); + } Ok(llm_tools) } } @@ -74,26 +91,25 @@ impl SystemToolInput { /// Trait for all system tools which allows validating input parameters and executing the tool. 
#[async_trait] pub trait SystemTool: Send + Sync { - fn input_schema(&self, tool_name: &str) -> &serde_json::Value; + fn input_schema(&self, tool_name: &str) -> ToolResult<&serde_json::Value>; async fn execute( - &self, + &mut self, tool_name: &str, - parameters: &ToolParameters, + parameters: serde_json::Value, sender: &SenderWithLogging, ) -> ToolResult<(String, ToolResponseFormat)>; async fn validate_and_execute( - &self, + &mut self, tool_name: &str, parameters: &ToolParameters, tx: &SenderWithLogging, ) -> ToolResult<(String, ToolResponseFormat)> { - jsonschema::validate( - self.input_schema(tool_name), - &serde_json::to_value(parameters)?, - ) - .map_err(|err| ToolError::InvalidParameters(err.to_string()))?; - self.execute(tool_name, parameters, tx).await + let params_value = serde_json::to_value(parameters)?; + jsonschema::validate(self.input_schema(tool_name)?, ¶ms_value) + .map_err(|err| ToolError::InvalidParameters(err.to_string()))?; + + self.execute(tool_name, params_value, tx).await } } @@ -108,22 +124,35 @@ trait SystemToolConfig { fn get_llm_tools( &self, tool_id: Uuid, - input_config: Option, + input_config: Option<&Self::DynamicConfig>, ) -> Vec; /// Validates the configuration of the system tool. 
fn validate(&self) -> ToolResult<()>; } -impl ChatRsSystemTool { +impl<'a> ChatRsSystemTool { /// Create the system tool executor from the database entity - pub fn build_executor(&self) -> Box { + pub fn build_executor( + &'a self, + db: &'a mut DbConnection, + app_config: &'a AppConfig, + session_id: &'a Uuid, + ) -> Box { match &self.data { ChatRsSystemToolConfig::CodeRunner(config) => { Box::new(code_runner::CodeRunner::new(config)) } - ChatRsSystemToolConfig::SystemInfo => Box::new(system_info::SystemInfo::new()), - ChatRsSystemToolConfig::Files(_) => unimplemented!(), + ChatRsSystemToolConfig::SystemInfo => { + Box::new(system_info::SystemInfo::new(app_config)) + } + ChatRsSystemToolConfig::Files(config) => Box::new(files::Files::new( + &self.user_id, + session_id, + app_config, + db, + config, + )), } } } diff --git a/server/src/tools/system/code_runner.rs b/server/src/tools/system/code_runner.rs index a74c6c7..cdd245a 100644 --- a/server/src/tools/system/code_runner.rs +++ b/server/src/tools/system/code_runner.rs @@ -12,7 +12,7 @@ use uuid::Uuid; use crate::{ provider::{LlmTool, LlmToolType}, tools::{ - core::{ToolLog, ToolParameters, ToolResponseFormat, ToolResult}, + core::{ToolLog, ToolResponseFormat, ToolResult}, system::{SystemTool, SystemToolConfig}, utils::get_json_schema, ToolError, @@ -105,7 +105,7 @@ impl SystemToolConfig for CodeRunnerConfig { .map_err(|e| ToolError::InvalidConfiguration(e.to_string())) } - fn get_llm_tools(&self, tool_id: Uuid, _input_config: Option<()>) -> Vec { + fn get_llm_tools(&self, tool_id: Uuid, _input_config: Option<&()>) -> Vec { vec![LlmTool { name: CODE_RUNNER_NAME.into(), description: CODE_RUNNER_DESCRIPTION.into(), @@ -118,18 +118,17 @@ impl SystemToolConfig for CodeRunnerConfig { #[async_trait] impl SystemTool for CodeRunner<'_> { - fn input_schema(&self, _tool_name: &str) -> &serde_json::Value { - &CODE_RUNNER_INPUT_SCHEMA + fn input_schema(&self, _tool_name: &str) -> ToolResult<&serde_json::Value> { + 
Ok(&CODE_RUNNER_INPUT_SCHEMA) } async fn execute( - &self, + &mut self, _tool_name: &str, - params: &ToolParameters, + params: serde_json::Value, sender: &SenderWithLogging, ) -> ToolResult<(String, ToolResponseFormat)> { - let input = serde_json::from_value::(serde_json::to_value(params)?) - .map_err(|e| ToolError::InvalidParameters(e.to_string()))?; + let input = serde_json::from_value::(params)?; let executor = DockerExecutor::new( input.language, DockerExecutorOptions { diff --git a/server/src/tools/system/files.rs b/server/src/tools/system/files.rs new file mode 100644 index 0000000..9d38961 --- /dev/null +++ b/server/src/tools/system/files.rs @@ -0,0 +1,237 @@ +use std::{ + path::{Path, PathBuf}, + sync::LazyLock, +}; + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::{ + db::{ + models::{ChatRsFileType, NewChatRsFile}, + services::FileDbService, + DbConnection, + }, + provider::LlmToolType, + storage::{LocalStorage, DEFAULT_DATA_DIR}, +}; + +use super::*; + +/// Tools for listing, reading, and writing files in the current chat session. 
+pub struct Files<'a> { + user_id: &'a uuid::Uuid, + session_id: &'a uuid::Uuid, + app_config: &'a AppConfig, + db: &'a mut DbConnection, + config: &'a FilesConfig, +} +impl<'a> Files<'a> { + pub fn new( + user_id: &'a uuid::Uuid, + session_id: &'a uuid::Uuid, + app_config: &'a AppConfig, + db: &'a mut DbConnection, + config: &'a FilesConfig, + ) -> Self { + Self { + user_id, + session_id, + app_config, + db, + config, + } + } +} + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct FilesConfig { + storage: StorageType, +} + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "snake_case")] +enum StorageType { + Local, +} + +#[derive(Debug, Default, PartialEq, JsonSchema, Serialize, Deserialize)] +pub struct FilesInput { + /// Whether assistant has permission to read files + #[serde(default)] + read: bool, + /// Whether assistant has permission to write files + #[serde(default)] + write: bool, +} + +const LIST_FILES: &str = "list_files"; +const LIST_FILES_DESC: &str = "List files for the current chat session"; +static LIST_FILES_SCHEMA: LazyLock = + LazyLock::new(|| utils::get_json_schema::()); + +const READ_FILE: &str = "read_file"; +const READ_FILE_DESC: &str = "Read a file in the current chat session"; +static READ_FILE_SCHEMA: LazyLock = + LazyLock::new(|| utils::get_json_schema::()); + +const WRITE_FILE: &str = "write_file"; +const WRITE_FILE_DESC: &str = "Write a file in the current chat session"; +static WRITE_FILE_SCHEMA: LazyLock = + LazyLock::new(|| utils::get_json_schema::()); + +#[derive(Debug, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +struct ListFilesInput { + /// Optional directory path to list files from + #[schemars(example = "dir_path_example")] + dir: Option, +} +fn dir_path_example() -> &'static str { + "foo/dir" +} + +#[derive(Debug, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +struct ReadFileInput { + /// Path of the file to read. Should be a relative path. 
+    #[schemars(example = "file_path_example")]
+    path: String,
+}
+
+#[derive(Debug, Deserialize, JsonSchema)]
+#[serde(deny_unknown_fields)]
+struct WriteFileInput {
+    /// Path of the file to write. Should be a relative path.
+    #[schemars(example = "file_path_example")]
+    path: String,
+    /// Content to write to the file
+    content: String,
+}
+fn file_path_example() -> &'static str {
+    "foo/file.txt"
+}
+
+impl SystemToolConfig for FilesConfig {
+    type DynamicConfig = FilesInput;
+
+    fn get_llm_tools(
+        &self,
+        tool_id: uuid::Uuid,
+        input_config: Option<&Self::DynamicConfig>,
+    ) -> Vec<LlmTool> {
+        let mut tools = Vec::with_capacity(3);
+
+        if input_config.map_or(true, |c| c.read) {
+            tools.push(LlmTool {
+                tool_id,
+                name: READ_FILE.into(),
+                description: READ_FILE_DESC.into(),
+                input_schema: READ_FILE_SCHEMA.to_owned(),
+                tool_type: LlmToolType::System,
+            });
+            tools.push(LlmTool {
+                tool_id,
+                name: LIST_FILES.into(),
+                description: LIST_FILES_DESC.into(),
+                input_schema: LIST_FILES_SCHEMA.to_owned(),
+                tool_type: LlmToolType::System,
+            });
+        }
+        if input_config.map_or(true, |c| c.write) {
+            tools.push(LlmTool {
+                tool_id,
+                name: WRITE_FILE.into(),
+                description: WRITE_FILE_DESC.into(),
+                input_schema: WRITE_FILE_SCHEMA.to_owned(),
+                tool_type: LlmToolType::System,
+            });
+        }
+
+        tools
+    }
+
+    fn validate(&self) -> ToolResult<()> {
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl SystemTool for Files<'_> {
+    fn input_schema(&self, tool_name: &str) -> ToolResult<&serde_json::Value> {
+        match tool_name {
+            READ_FILE => Ok(&READ_FILE_SCHEMA),
+            LIST_FILES => Ok(&LIST_FILES_SCHEMA),
+            WRITE_FILE => Ok(&WRITE_FILE_SCHEMA),
+            _ => Err(ToolError::ToolNotFound),
+        }
+    }
+
+    async fn execute(
+        &mut self,
+        tool_name: &str,
+        parameters: serde_json::Value,
+        _sender: &SenderWithLogging,
+    ) -> ToolResult<(String, ToolResponseFormat)> {
+        let storage = match self.config.storage {
+            StorageType::Local => {
+                let data_dir = self
+                    .app_config
+                    .data_dir
+                    .as_deref()
+                    .unwrap_or(DEFAULT_DATA_DIR);
+                LocalStorage::new(PathBuf::from(data_dir).join("storage"))
+            }
+        };
+
+        match tool_name {
+            READ_FILE => {
+                let input: ReadFileInput = serde_json::from_value(parameters)?;
+                let path = Path::new(&input.path);
+                let content_bytes = storage
+                    .read_file_as_bytes(self.user_id, Some(self.session_id), path)
+                    .await?;
+                let content = String::from_utf8_lossy(&content_bytes);
+                Ok((content.into(), ToolResponseFormat::Text))
+            }
+            LIST_FILES => {
+                let input: ListFilesInput = serde_json::from_value(parameters)?;
+                let mut files = FileDbService::new(self.db)
+                    .list_session_files(self.user_id, self.session_id)
+                    .await?;
+                if let Some(dir) = input.dir {
+                    files.retain(|file| file.path.starts_with(&dir));
+                }
+
+                Ok((serde_json::to_string(&files)?, ToolResponseFormat::Json))
+            }
+            WRITE_FILE => {
+                let input: WriteFileInput = serde_json::from_value(parameters)?;
+                let size = storage
+                    .create_file(
+                        self.user_id,
+                        Some(self.session_id),
+                        Path::new(&input.path),
+                        input.content.as_bytes(),
+                    )
+                    .await?;
+                let file = FileDbService::new(self.db)
+                    .create_session_file(NewChatRsFile {
+                        user_id: self.user_id,
+                        session_id: Some(self.session_id),
+                        path: &input.path,
+                        file_type: ChatRsFileType::Text.into(),
+                        content_type: "text/plain",
+                        size: size.try_into().unwrap_or_default(),
+                    })
+                    .await?;
+
+                let message = format!("File '{}' created with ID {}", input.path, file.id);
+                Ok((message, ToolResponseFormat::Text))
+            }
+            _ => Err(ToolError::ToolNotFound),
+        }
+    }
+}
diff --git a/server/src/tools/system/system_info.rs b/server/src/tools/system/system_info.rs
index 370c387..4deec5d 100644
--- a/server/src/tools/system/system_info.rs
+++ b/server/src/tools/system/system_info.rs
@@ -4,16 +4,18 @@ use rocket::async_trait;
 use schemars::JsonSchema;
 
 use crate::{
-    provider::{LlmTool, LlmToolType},
+    config::AppConfig,
+    provider::LlmToolType,
     tools::{system::SystemToolConfig, utils::get_json_schema},
     utils::SenderWithLogging,
 };
 
-use super::{SystemTool, ToolError, ToolLog, ToolParameters, ToolResponseFormat, ToolResult};
+use super::*;
 
 const TOOL_PREFIX: &str = "system_";
 
-static JSON_SCHEMA: LazyLock<serde_json::Value> = LazyLock::new(|| get_json_schema::<SystemInfo>());
+static JSON_SCHEMA: LazyLock<serde_json::Value> =
+    LazyLock::new(|| get_json_schema::<SystemInfoConfig>());
 
 const DATE_TIME_NAME: &str = "datetime_now";
 const DATE_TIME_DESC: &str = "Get the current date and time in RFC3339 format. \
@@ -24,23 +26,26 @@ const SERVER_URL_DESC: &str = "Get the URL of the server that this chat applicat
     This may be useful to help direct the user to files or other resources that are hosted on the server.";
 
 /// Tool to get system information.
-#[derive(Debug, JsonSchema)]
-#[serde(deny_unknown_fields)]
-pub struct SystemInfo {}
-impl SystemInfo {
-    pub fn new() -> Self {
-        SystemInfo {}
+pub struct SystemInfo<'a> {
+    app_config: &'a AppConfig,
+}
+impl<'a> SystemInfo<'a> {
+    pub fn new(app_config: &'a AppConfig) -> Self {
+        SystemInfo { app_config }
     }
 }
 
+#[derive(Debug, JsonSchema)]
+#[serde(deny_unknown_fields)]
 pub struct SystemInfoConfig {}
+
 impl SystemToolConfig for SystemInfoConfig {
     type DynamicConfig = ();
 
     fn get_llm_tools(
         &self,
         tool_id: uuid::Uuid,
-        _input_config: Option<Self::DynamicConfig>,
+        _input_config: Option<&Self::DynamicConfig>,
     ) -> Vec<LlmTool> {
         vec![
             LlmTool {
@@ -66,15 +71,15 @@ impl SystemToolConfig for SystemInfoConfig {
 }
 
 #[async_trait]
-impl SystemTool for SystemInfo {
-    fn input_schema(&self, _tool_name: &str) -> &serde_json::Value {
-        &JSON_SCHEMA
+impl SystemTool for SystemInfo<'_> {
+    fn input_schema(&self, _tool_name: &str) -> ToolResult<&serde_json::Value> {
+        Ok(&JSON_SCHEMA)
     }
 
     async fn execute(
-        &self,
+        &mut self,
         tool_name: &str,
-        _params: &ToolParameters,
+        _params: serde_json::Value,
         _tx: &SenderWithLogging,
     ) -> ToolResult<(String, ToolResponseFormat)> {
         match tool_name.strip_prefix(TOOL_PREFIX) {
@@ -82,12 +87,10 @@ impl SystemTool for SystemInfo {
                 let now = chrono::Utc::now();
Ok((now.to_rfc3339(), ToolResponseFormat::Text)) } - Some(SERVER_URL_NAME) => { - let server_url = std::env::var("RS_CHAT_SERVER_ADDRESS").map_err(|_| { - ToolError::ToolExecutionError("Could not determine Server URL".into()) - })?; - Ok((server_url, ToolResponseFormat::Text)) - } + Some(SERVER_URL_NAME) => Ok(( + self.app_config.server_address.clone(), + ToolResponseFormat::Text, + )), _ => Err(ToolError::ToolNotFound), } } diff --git a/web/package.json b/web/package.json index bfffb6e..5bfb058 100644 --- a/web/package.json +++ b/web/package.json @@ -42,6 +42,7 @@ "openapi-fetch": "^0.14.0", "react": "^19.0.0", "react-dom": "^19.0.0", + "react-dropzone": "^14.3.8", "react-markdown": "^10.1.0", "rehype-highlight": "^7.0.2", "rehype-highlight-code-lines": "^1.1.5", @@ -54,7 +55,7 @@ "vaul": "^1.1.2" }, "devDependencies": { - "@biomejs/biome": "2.0.0", + "@biomejs/biome": "2.2.2", "@tailwindcss/typography": "^0.5.16", "@testing-library/dom": "^10.4.0", "@testing-library/react": "^16.2.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 0e4906b..3560a2a 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -86,6 +86,9 @@ importers: react-dom: specifier: ^19.0.0 version: 19.1.0(react@19.1.0) + react-dropzone: + specifier: ^14.3.8 + version: 14.3.8(react@19.1.0) react-markdown: specifier: ^10.1.0 version: 10.1.0(@types/react@19.1.8)(react@19.1.0) @@ -118,8 +121,8 @@ importers: version: 1.1.2(@types/react-dom@19.1.6(@types/react@19.1.8))(@types/react@19.1.8)(react-dom@19.1.0(react@19.1.0))(react@19.1.0) devDependencies: '@biomejs/biome': - specifier: 2.0.0 - version: 2.0.0 + specifier: 2.2.2 + version: 2.2.2 '@tailwindcss/typography': specifier: ^0.5.16 version: 0.5.16(tailwindcss@4.1.10) @@ -307,55 +310,55 @@ packages: resolution: {integrity: sha512-ETyHEk2VHHvl9b9jZP5IHPavHYk57EhanlRRuae9XCpb/j5bDCbPPMOBfCWhnl/7EDJz0jEMCi/RhccCE8r1+Q==} engines: {node: '>=6.9.0'} - '@biomejs/biome@2.0.0': - resolution: {integrity: 
sha512-BlUoXEOI/UQTDEj/pVfnkMo8SrZw3oOWBDrXYFT43V7HTkIUDkBRY53IC5Jx1QkZbaB+0ai1wJIfYwp9+qaJTQ==} + '@biomejs/biome@2.2.2': + resolution: {integrity: sha512-j1omAiQWCkhuLgwpMKisNKnsM6W8Xtt1l0WZmqY/dFj8QPNkIoTvk4tSsi40FaAAkBE1PU0AFG2RWFBWenAn+w==} engines: {node: '>=14.21.3'} hasBin: true - '@biomejs/cli-darwin-arm64@2.0.0': - resolution: {integrity: sha512-QvqWYtFFhhxdf8jMAdJzXW+Frc7X8XsnHQLY+TBM1fnT1TfeV/v9vsFI5L2J7GH6qN1+QEEJ19jHibCY2Ypplw==} + '@biomejs/cli-darwin-arm64@2.2.2': + resolution: {integrity: sha512-6ePfbCeCPryWu0CXlzsWNZgVz/kBEvHiPyNpmViSt6A2eoDf4kXs3YnwQPzGjy8oBgQulrHcLnJL0nkCh80mlQ==} engines: {node: '>=14.21.3'} cpu: [arm64] os: [darwin] - '@biomejs/cli-darwin-x64@2.0.0': - resolution: {integrity: sha512-5JFhls1EfmuIH4QGFPlNpxJQFC6ic3X1ltcoLN+eSRRIPr6H/lUS1ttuD0Fj7rPgPhZqopK/jfH8UVj/1hIsQw==} + '@biomejs/cli-darwin-x64@2.2.2': + resolution: {integrity: sha512-Tn4JmVO+rXsbRslml7FvKaNrlgUeJot++FkvYIhl1OkslVCofAtS35MPlBMhXgKWF9RNr9cwHanrPTUUXcYGag==} engines: {node: '>=14.21.3'} cpu: [x64] os: [darwin] - '@biomejs/cli-linux-arm64-musl@2.0.0': - resolution: {integrity: sha512-Bxsz8ki8+b3PytMnS5SgrGV+mbAWwIxI3ydChb/d1rURlJTMdxTTq5LTebUnlsUWAX6OvJuFeiVq9Gjn1YbCyA==} + '@biomejs/cli-linux-arm64-musl@2.2.2': + resolution: {integrity: sha512-/MhYg+Bd6renn6i1ylGFL5snYUn/Ct7zoGVKhxnro3bwekiZYE8Kl39BSb0MeuqM+72sThkQv4TnNubU9njQRw==} engines: {node: '>=14.21.3'} cpu: [arm64] os: [linux] - '@biomejs/cli-linux-arm64@2.0.0': - resolution: {integrity: sha512-BAH4QVi06TzAbVchXdJPsL0Z/P87jOfes15rI+p3EX9/EGTfIjaQ9lBVlHunxcmoptaA5y1Hdb9UYojIhmnjIw==} + '@biomejs/cli-linux-arm64@2.2.2': + resolution: {integrity: sha512-JfrK3gdmWWTh2J5tq/rcWCOsImVyzUnOS2fkjhiYKCQ+v8PqM+du5cfB7G1kXas+7KQeKSWALv18iQqdtIMvzw==} engines: {node: '>=14.21.3'} cpu: [arm64] os: [linux] - '@biomejs/cli-linux-x64-musl@2.0.0': - resolution: {integrity: sha512-tiQ0ABxMJb9I6GlfNp0ulrTiQSFacJRJO8245FFwE3ty3bfsfxlU/miblzDIi+qNrgGsLq5wIZcVYGp4c+HXZA==} + '@biomejs/cli-linux-x64-musl@2.2.2': + 
resolution: {integrity: sha512-ZCLXcZvjZKSiRY/cFANKg+z6Fhsf9MHOzj+NrDQcM+LbqYRT97LyCLWy2AS+W2vP+i89RyRM+kbGpUzbRTYWig==} engines: {node: '>=14.21.3'} cpu: [x64] os: [linux] - '@biomejs/cli-linux-x64@2.0.0': - resolution: {integrity: sha512-09PcOGYTtkopWRm6mZ/B6Mr6UHdkniUgIG/jLBv+2J8Z61ezRE+xQmpi3yNgUrFIAU4lPA9atg7mhvE/5Bo7Wg==} + '@biomejs/cli-linux-x64@2.2.2': + resolution: {integrity: sha512-Ogb+77edO5LEP/xbNicACOWVLt8mgC+E1wmpUakr+O4nKwLt9vXe74YNuT3T1dUBxC/SnrVmlzZFC7kQJEfquQ==} engines: {node: '>=14.21.3'} cpu: [x64] os: [linux] - '@biomejs/cli-win32-arm64@2.0.0': - resolution: {integrity: sha512-vrTtuGu91xNTEQ5ZcMJBZuDlqr32DWU1r14UfePIGndF//s2WUAmer4FmgoPgruo76rprk37e8S2A2c0psXdxw==} + '@biomejs/cli-win32-arm64@2.2.2': + resolution: {integrity: sha512-wBe2wItayw1zvtXysmHJQoQqXlTzHSpQRyPpJKiNIR21HzH/CrZRDFic1C1jDdp+zAPtqhNExa0owKMbNwW9cQ==} engines: {node: '>=14.21.3'} cpu: [arm64] os: [win32] - '@biomejs/cli-win32-x64@2.0.0': - resolution: {integrity: sha512-2USVQ0hklNsph/KIR72ZdeptyXNnQ3JdzPn3NbjI4Sna34CnxeiYAaZcZzXPDl5PYNFBivV4xmvT3Z3rTmyDBg==} + '@biomejs/cli-win32-x64@2.2.2': + resolution: {integrity: sha512-DAuHhHekGfiGb6lCcsT4UyxQmVwQiBCBUMwVra/dcOSs9q8OhfaZgey51MlekT3p8UwRqtXQfFuEJBhJNdLZwg==} engines: {node: '>=14.21.3'} cpu: [x64] os: [win32] @@ -1435,6 +1438,10 @@ packages: resolution: {integrity: sha512-6t10qk83GOG8p0vKmaCr8eiilZwO171AvbROMtvvNiwrTly62t+7XkA8RdIIVbpMhCASAsxgAzdRSwh6nw/5Dg==} engines: {node: '>=4'} + attr-accept@2.2.5: + resolution: {integrity: sha512-0bDNnY/u6pPwHDMoF0FieU354oBi0a8rD9FcsLwzcGWbc8KS8KPIi7y+s13OlVY+gMWc/9xEMUgNE6Qm8ZllYQ==} + engines: {node: '>=4'} + babel-dead-code-elimination@1.0.10: resolution: {integrity: sha512-DV5bdJZTzZ0zn0DC24v3jD7Mnidh6xhKa4GfKCbq3sfW8kaWhDdZjP3i81geA8T33tdYqWKw4D3fVv0CwEgKVA==} @@ -1635,6 +1642,10 @@ packages: picomatch: optional: true + file-selector@2.1.2: + resolution: {integrity: sha512-QgXo+mXTe8ljeqUFaX3QVHc5osSItJ/Km+xpocx0aSqWGMSCf6qYs/VnzZgS864Pjn5iceMRFigeAV7AfTlaig==} + engines: 
{node: '>= 12'} + fill-range@7.1.1: resolution: {integrity: sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==} engines: {node: '>=8'} @@ -1856,6 +1867,10 @@ packages: longest-streak@3.1.0: resolution: {integrity: sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==} + loose-envify@1.4.0: + resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} + hasBin: true + loupe@3.1.3: resolution: {integrity: sha512-kkIp7XSkP78ZxJEsSxW3712C6teJVoeHHwgo9zJ380de7IYyJ2ISlxojcH2pC5OFLewESmnRi/+XCDIEEVyoug==} @@ -2043,6 +2058,10 @@ packages: nwsapi@2.2.20: resolution: {integrity: sha512-/ieB+mDe4MrrKMT8z+mQL8klXydZWGR5Dowt4RAGKbJ3kIGEx3X4ljUo+6V73IXtUPWgfOlU5B9MlGxFO5T+cA==} + object-assign@4.1.1: + resolution: {integrity: sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==} + engines: {node: '>=0.10.0'} + openapi-fetch@0.14.0: resolution: {integrity: sha512-PshIdm1NgdLvb05zp8LqRQMNSKzIlPkyMxYFxwyHR+UlKD4t2nUjkDhNxeRbhRSEd3x5EUNh2w5sJYwkhOH4fg==} @@ -2093,6 +2112,9 @@ packages: resolution: {integrity: sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==} engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0} + prop-types@15.8.1: + resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} + property-information@7.1.0: resolution: {integrity: sha512-TwEZ+X+yCJmYfL7TPUOcvBZ4QfoT5YenQiJuX//0th53DE6w0xxLEtfK3iyryQFddXuvkIk51EEgrJQ0WJkOmQ==} @@ -2105,6 +2127,15 @@ packages: peerDependencies: react: ^19.1.0 + react-dropzone@14.3.8: + resolution: {integrity: sha512-sBgODnq+lcA4P296DY4wacOZz3JFpD99fp+hb//iBO2HHnyeZU3FwWyXJ6salNpqQdsZrgMrotuko/BdJMV8Ug==} + engines: {node: '>= 10.13'} + peerDependencies: + react: '>= 16.8 || 18.0.0' + + react-is@16.13.1: + resolution: {integrity: 
sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==} + react-is@17.0.2: resolution: {integrity: sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==} @@ -2776,39 +2807,39 @@ snapshots: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.27.1 - '@biomejs/biome@2.0.0': + '@biomejs/biome@2.2.2': optionalDependencies: - '@biomejs/cli-darwin-arm64': 2.0.0 - '@biomejs/cli-darwin-x64': 2.0.0 - '@biomejs/cli-linux-arm64': 2.0.0 - '@biomejs/cli-linux-arm64-musl': 2.0.0 - '@biomejs/cli-linux-x64': 2.0.0 - '@biomejs/cli-linux-x64-musl': 2.0.0 - '@biomejs/cli-win32-arm64': 2.0.0 - '@biomejs/cli-win32-x64': 2.0.0 - - '@biomejs/cli-darwin-arm64@2.0.0': + '@biomejs/cli-darwin-arm64': 2.2.2 + '@biomejs/cli-darwin-x64': 2.2.2 + '@biomejs/cli-linux-arm64': 2.2.2 + '@biomejs/cli-linux-arm64-musl': 2.2.2 + '@biomejs/cli-linux-x64': 2.2.2 + '@biomejs/cli-linux-x64-musl': 2.2.2 + '@biomejs/cli-win32-arm64': 2.2.2 + '@biomejs/cli-win32-x64': 2.2.2 + + '@biomejs/cli-darwin-arm64@2.2.2': optional: true - '@biomejs/cli-darwin-x64@2.0.0': + '@biomejs/cli-darwin-x64@2.2.2': optional: true - '@biomejs/cli-linux-arm64-musl@2.0.0': + '@biomejs/cli-linux-arm64-musl@2.2.2': optional: true - '@biomejs/cli-linux-arm64@2.0.0': + '@biomejs/cli-linux-arm64@2.2.2': optional: true - '@biomejs/cli-linux-x64-musl@2.0.0': + '@biomejs/cli-linux-x64-musl@2.2.2': optional: true - '@biomejs/cli-linux-x64@2.0.0': + '@biomejs/cli-linux-x64@2.2.2': optional: true - '@biomejs/cli-win32-arm64@2.0.0': + '@biomejs/cli-win32-arm64@2.2.2': optional: true - '@biomejs/cli-win32-x64@2.0.0': + '@biomejs/cli-win32-x64@2.2.2': optional: true '@csstools/color-helpers@5.0.2': {} @@ -3786,6 +3817,8 @@ snapshots: dependencies: tslib: 2.8.1 + attr-accept@2.2.5: {} + babel-dead-code-elimination@1.0.10: dependencies: '@babel/core': 7.27.4 @@ -3982,6 +4015,10 @@ snapshots: optionalDependencies: picomatch: 4.0.2 + 
file-selector@2.1.2: + dependencies: + tslib: 2.8.1 + fill-range@7.1.1: dependencies: to-regex-range: 5.0.1 @@ -4193,6 +4230,10 @@ snapshots: longest-streak@3.1.0: {} + loose-envify@1.4.0: + dependencies: + js-tokens: 4.0.0 + loupe@3.1.3: {} lowlight@3.3.0: @@ -4581,6 +4622,8 @@ snapshots: nwsapi@2.2.20: {} + object-assign@4.1.1: {} + openapi-fetch@0.14.0: dependencies: openapi-typescript-helpers: 0.0.15 @@ -4632,6 +4675,12 @@ snapshots: ansi-styles: 5.2.0 react-is: 17.0.2 + prop-types@15.8.1: + dependencies: + loose-envify: 1.4.0 + object-assign: 4.1.1 + react-is: 16.13.1 + property-information@7.1.0: {} punycode@2.3.1: {} @@ -4641,6 +4690,15 @@ snapshots: react: 19.1.0 scheduler: 0.26.0 + react-dropzone@14.3.8(react@19.1.0): + dependencies: + attr-accept: 2.2.5 + file-selector: 2.1.2 + prop-types: 15.8.1 + react: 19.1.0 + + react-is@16.13.1: {} + react-is@17.0.2: {} react-markdown@10.1.0(@types/react@19.1.8)(react@19.1.0): diff --git a/web/src/components/ApiKeysManager.tsx b/web/src/components/ApiKeysManager.tsx index 1945851..1f0a917 100644 --- a/web/src/components/ApiKeysManager.tsx +++ b/web/src/components/ApiKeysManager.tsx @@ -1,5 +1,5 @@ import { Bot, Check, Copy, ExternalLink, Plus, Trash2 } from "lucide-react"; -import { useState } from "react"; +import { useId, useState } from "react"; import { AlertDialog, @@ -80,6 +80,9 @@ export function ApiKeysManager({ }); }; + const nameId = useId(); + const valueId = useId(); + return (
- + {newApiKeyValue && (
-