From 28bddec1a8c3633eb63ce27b4250ec3cd7232824 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 29 Aug 2025 20:54:41 -0400 Subject: [PATCH 01/26] web: update biome and fix lint errors --- biome.json | 2 +- web/package.json | 2 +- web/pnpm-lock.yaml | 76 +++++++++++++-------------- web/src/components/ApiKeysManager.tsx | 13 +++-- web/src/components/SearchForm.tsx | 7 ++- web/src/styles.css | 12 +++-- 6 files changed, 60 insertions(+), 52 deletions(-) diff --git a/biome.json b/biome.json index 1653499..840301a 100644 --- a/biome.json +++ b/biome.json @@ -1,5 +1,5 @@ { - "$schema": "https://biomejs.dev/schemas/2.0.0/schema.json", + "$schema": "https://biomejs.dev/schemas/2.2.2/schema.json", "vcs": { "enabled": true, "clientKind": "git", diff --git a/web/package.json b/web/package.json index bfffb6e..6c2f8f3 100644 --- a/web/package.json +++ b/web/package.json @@ -54,7 +54,7 @@ "vaul": "^1.1.2" }, "devDependencies": { - "@biomejs/biome": "2.0.0", + "@biomejs/biome": "2.2.2", "@tailwindcss/typography": "^0.5.16", "@testing-library/dom": "^10.4.0", "@testing-library/react": "^16.2.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 0e4906b..801b186 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -118,8 +118,8 @@ importers: version: 1.1.2(@types/react-dom@19.1.6(@types/react@19.1.8))(@types/react@19.1.8)(react-dom@19.1.0(react@19.1.0))(react@19.1.0) devDependencies: '@biomejs/biome': - specifier: 2.0.0 - version: 2.0.0 + specifier: 2.2.2 + version: 2.2.2 '@tailwindcss/typography': specifier: ^0.5.16 version: 0.5.16(tailwindcss@4.1.10) @@ -307,55 +307,55 @@ packages: resolution: {integrity: sha512-ETyHEk2VHHvl9b9jZP5IHPavHYk57EhanlRRuae9XCpb/j5bDCbPPMOBfCWhnl/7EDJz0jEMCi/RhccCE8r1+Q==} engines: {node: '>=6.9.0'} - '@biomejs/biome@2.0.0': - resolution: {integrity: sha512-BlUoXEOI/UQTDEj/pVfnkMo8SrZw3oOWBDrXYFT43V7HTkIUDkBRY53IC5Jx1QkZbaB+0ai1wJIfYwp9+qaJTQ==} + '@biomejs/biome@2.2.2': + resolution: {integrity: sha512-j1omAiQWCkhuLgwpMKisNKnsM6W8Xtt1l0WZmqY/dFj8QPNkIoTvk4tSsi40FaAAkBE1PU0AFG2RWFBWenAn+w==} engines: {node: '>=14.21.3'} hasBin: true - '@biomejs/cli-darwin-arm64@2.0.0': - resolution: {integrity: sha512-QvqWYtFFhhxdf8jMAdJzXW+Frc7X8XsnHQLY+TBM1fnT1TfeV/v9vsFI5L2J7GH6qN1+QEEJ19jHibCY2Ypplw==} + '@biomejs/cli-darwin-arm64@2.2.2': + resolution: {integrity: sha512-6ePfbCeCPryWu0CXlzsWNZgVz/kBEvHiPyNpmViSt6A2eoDf4kXs3YnwQPzGjy8oBgQulrHcLnJL0nkCh80mlQ==} engines: {node: '>=14.21.3'} cpu: [arm64] os: [darwin] - '@biomejs/cli-darwin-x64@2.0.0': - resolution: {integrity: sha512-5JFhls1EfmuIH4QGFPlNpxJQFC6ic3X1ltcoLN+eSRRIPr6H/lUS1ttuD0Fj7rPgPhZqopK/jfH8UVj/1hIsQw==} + '@biomejs/cli-darwin-x64@2.2.2': + resolution: {integrity: sha512-Tn4JmVO+rXsbRslml7FvKaNrlgUeJot++FkvYIhl1OkslVCofAtS35MPlBMhXgKWF9RNr9cwHanrPTUUXcYGag==} engines: {node: '>=14.21.3'} cpu: [x64] os: [darwin] - '@biomejs/cli-linux-arm64-musl@2.0.0': - resolution: {integrity: sha512-Bxsz8ki8+b3PytMnS5SgrGV+mbAWwIxI3ydChb/d1rURlJTMdxTTq5LTebUnlsUWAX6OvJuFeiVq9Gjn1YbCyA==} + '@biomejs/cli-linux-arm64-musl@2.2.2': + resolution: {integrity: sha512-/MhYg+Bd6renn6i1ylGFL5snYUn/Ct7zoGVKhxnro3bwekiZYE8Kl39BSb0MeuqM+72sThkQv4TnNubU9njQRw==} engines: {node: '>=14.21.3'} cpu: [arm64] os: [linux] - '@biomejs/cli-linux-arm64@2.0.0': - resolution: {integrity: sha512-BAH4QVi06TzAbVchXdJPsL0Z/P87jOfes15rI+p3EX9/EGTfIjaQ9lBVlHunxcmoptaA5y1Hdb9UYojIhmnjIw==} + '@biomejs/cli-linux-arm64@2.2.2': + resolution: {integrity: sha512-JfrK3gdmWWTh2J5tq/rcWCOsImVyzUnOS2fkjhiYKCQ+v8PqM+du5cfB7G1kXas+7KQeKSWALv18iQqdtIMvzw==} engines: {node: '>=14.21.3'} cpu: [arm64] os: [linux] - '@biomejs/cli-linux-x64-musl@2.0.0': - resolution: {integrity: sha512-tiQ0ABxMJb9I6GlfNp0ulrTiQSFacJRJO8245FFwE3ty3bfsfxlU/miblzDIi+qNrgGsLq5wIZcVYGp4c+HXZA==} + '@biomejs/cli-linux-x64-musl@2.2.2': + resolution: {integrity: sha512-ZCLXcZvjZKSiRY/cFANKg+z6Fhsf9MHOzj+NrDQcM+LbqYRT97LyCLWy2AS+W2vP+i89RyRM+kbGpUzbRTYWig==} engines: {node: '>=14.21.3'} cpu: [x64] os: [linux] - '@biomejs/cli-linux-x64@2.0.0': - resolution: {integrity: sha512-09PcOGYTtkopWRm6mZ/B6Mr6UHdkniUgIG/jLBv+2J8Z61ezRE+xQmpi3yNgUrFIAU4lPA9atg7mhvE/5Bo7Wg==} + '@biomejs/cli-linux-x64@2.2.2': + resolution: {integrity: sha512-Ogb+77edO5LEP/xbNicACOWVLt8mgC+E1wmpUakr+O4nKwLt9vXe74YNuT3T1dUBxC/SnrVmlzZFC7kQJEfquQ==} engines: {node: '>=14.21.3'} cpu: [x64] os: [linux] - '@biomejs/cli-win32-arm64@2.0.0': - resolution: {integrity: sha512-vrTtuGu91xNTEQ5ZcMJBZuDlqr32DWU1r14UfePIGndF//s2WUAmer4FmgoPgruo76rprk37e8S2A2c0psXdxw==} + '@biomejs/cli-win32-arm64@2.2.2': + resolution: {integrity: sha512-wBe2wItayw1zvtXysmHJQoQqXlTzHSpQRyPpJKiNIR21HzH/CrZRDFic1C1jDdp+zAPtqhNExa0owKMbNwW9cQ==} engines: {node: '>=14.21.3'} cpu: [arm64] os: [win32] - '@biomejs/cli-win32-x64@2.0.0': - resolution: {integrity: sha512-2USVQ0hklNsph/KIR72ZdeptyXNnQ3JdzPn3NbjI4Sna34CnxeiYAaZcZzXPDl5PYNFBivV4xmvT3Z3rTmyDBg==} + '@biomejs/cli-win32-x64@2.2.2': + resolution: {integrity: sha512-DAuHhHekGfiGb6lCcsT4UyxQmVwQiBCBUMwVra/dcOSs9q8OhfaZgey51MlekT3p8UwRqtXQfFuEJBhJNdLZwg==} engines: {node: '>=14.21.3'} cpu: [x64] os: [win32] @@ -2776,39 +2776,39 @@ snapshots: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.27.1 - '@biomejs/biome@2.0.0': + '@biomejs/biome@2.2.2': optionalDependencies: - '@biomejs/cli-darwin-arm64': 2.0.0 - '@biomejs/cli-darwin-x64': 2.0.0 - '@biomejs/cli-linux-arm64': 2.0.0 - '@biomejs/cli-linux-arm64-musl': 2.0.0 - '@biomejs/cli-linux-x64': 2.0.0 - '@biomejs/cli-linux-x64-musl': 2.0.0 - '@biomejs/cli-win32-arm64': 2.0.0 - '@biomejs/cli-win32-x64': 2.0.0 - - '@biomejs/cli-darwin-arm64@2.0.0': + '@biomejs/cli-darwin-arm64': 2.2.2 + '@biomejs/cli-darwin-x64': 2.2.2 + '@biomejs/cli-linux-arm64': 2.2.2 + '@biomejs/cli-linux-arm64-musl': 2.2.2 + '@biomejs/cli-linux-x64': 2.2.2 + '@biomejs/cli-linux-x64-musl': 2.2.2 + '@biomejs/cli-win32-arm64': 2.2.2 + '@biomejs/cli-win32-x64': 2.2.2 + + '@biomejs/cli-darwin-arm64@2.2.2': optional: true - '@biomejs/cli-darwin-x64@2.0.0': + '@biomejs/cli-darwin-x64@2.2.2': optional: true - '@biomejs/cli-linux-arm64-musl@2.0.0': + '@biomejs/cli-linux-arm64-musl@2.2.2': optional: true - '@biomejs/cli-linux-arm64@2.0.0': + '@biomejs/cli-linux-arm64@2.2.2': optional: true - '@biomejs/cli-linux-x64-musl@2.0.0': + '@biomejs/cli-linux-x64-musl@2.2.2': optional: true - '@biomejs/cli-linux-x64@2.0.0': + '@biomejs/cli-linux-x64@2.2.2': optional: true - '@biomejs/cli-win32-arm64@2.0.0': + '@biomejs/cli-win32-arm64@2.2.2': optional: true - '@biomejs/cli-win32-x64@2.0.0': + '@biomejs/cli-win32-x64@2.2.2': optional: true '@csstools/color-helpers@5.0.2': {} diff --git a/web/src/components/ApiKeysManager.tsx b/web/src/components/ApiKeysManager.tsx index 1945851..1f0a917 100644 --- a/web/src/components/ApiKeysManager.tsx +++ b/web/src/components/ApiKeysManager.tsx @@ -1,5 +1,5 @@ import { Bot, Check, Copy, ExternalLink, Plus, Trash2 } from "lucide-react"; -import { useState } from "react"; +import { useId, useState } from "react"; import { AlertDialog, @@ -80,6 +80,9 @@ export function ApiKeysManager({ }); }; + const nameId = useId(); + const valueId = useId(); + return (
- + {newApiKeyValue && (
-
diff --git a/web/src/components/chat/settings/ChatFileSelect.tsx b/web/src/components/chat/settings/ChatFileSelect.tsx new file mode 100644 index 0000000..23366d3 --- /dev/null +++ b/web/src/components/chat/settings/ChatFileSelect.tsx @@ -0,0 +1,203 @@ +import { Check, File, FileText, Image, Paperclip } from "lucide-react"; +import { useCallback, useMemo, useState } from "react"; + +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; +import { useSessionFiles } from "@/lib/api/storage"; +import type { components } from "@/lib/api/types"; +import { cn } from "@/lib/utils"; +import ChatSettingsBadge from "./ChatSettingsBadge"; + +interface FileSelectionDialogProps { + sessionId: string; + selectedFiles: components["schemas"]["ChatRsFile"][]; + onAddFile: (file: components["schemas"]["ChatRsFile"]) => void; + onRemoveFile: (fileId: string) => void; + onRemoveAllFiles: () => void; +} + +function getFileIcon(fileType: components["schemas"]["ChatRsFileType"]) { + switch (fileType) { + case "image": + return ; + case "pdf": + return ; + default: + return ; + } +} + +function formatFileSize(bytes: number): string { + if (bytes === 0) return "0 B"; + const k = 1024; + const sizes = ["B", "KB", "MB", "GB"]; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + return `${parseFloat((bytes / k ** i).toFixed(1))} ${sizes[i]}`; +} + +export default function ChatFileSelect({ + sessionId, + selectedFiles, + onAddFile, + onRemoveFile, + onRemoveAllFiles, +}: FileSelectionDialogProps) { + const [open, setOpen] = useState(false); + const { data: files, isLoading } = useSessionFiles(sessionId); + + const fileList = useMemo(() => files || [], [files]); + + const handleFileToggle = useCallback( + (file: components["schemas"]["ChatRsFile"], isSelected: boolean) => { + if (isSelected) { + onRemoveFile(file.id); + } else { + onAddFile(file); + } + }, + [onAddFile, onRemoveFile], + ); + + const handleSelectAll = useCallback(() => { + fileList.forEach((file) => { + if (!selectedFiles.some((f) => f.id === file.id)) { + onAddFile(file); + } + }); + }, [fileList, selectedFiles, onAddFile]); + + const handleDeselectAll = useCallback(() => { + onRemoveAllFiles(); + }, [onRemoveAllFiles]); + + const selectedCount = selectedFiles.length; + const totalCount = fileList.length; + + return ( + + + + + + + Attach Files + + Select files from your session to attach to your message. + + + +
+ {totalCount > 0 && ( +
+ + {selectedCount} of {totalCount} files selected + +
+ + +
+
+ )} + +
+ {isLoading ? ( +
+ Loading files... +
+ ) : fileList.length === 0 ? ( +
+ +

No files in this session

+

+ Upload files by dragging them to the chat input +

+
+ ) : ( +
+ {fileList.map((file) => { + const isSelected = selectedFiles.some( + (f) => f.id === file.id, + ); + return ( + + ); + })} +
+ )} +
+ +
+ +
+
+
+
+ ); +} diff --git a/web/src/components/chat/settings/ChatSettingsBadge.tsx b/web/src/components/chat/settings/ChatSettingsBadge.tsx new file mode 100644 index 0000000..fae864f --- /dev/null +++ b/web/src/components/chat/settings/ChatSettingsBadge.tsx @@ -0,0 +1,13 @@ +import { Badge } from "@/components/ui/badge"; + +export default function ChatSettingsBadge({ + children, +}: { + children: React.ReactNode; +}) { + return ( + + {children} + + ); +} diff --git a/web/src/components/chat/settings/ChatToolSelect.tsx b/web/src/components/chat/settings/ChatToolSelect.tsx index 2f24889..6368ef5 100644 --- a/web/src/components/chat/settings/ChatToolSelect.tsx +++ b/web/src/components/chat/settings/ChatToolSelect.tsx @@ -3,7 +3,6 @@ import { useMemo, useState } from "react"; import PopoverDrawer from "@/components/PopoverDrawer"; import { getToolIcon } from "@/components/ToolsManager"; -import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Command, @@ -15,6 +14,7 @@ import { } from "@/components/ui/command"; import type { useChatInputState } from "@/hooks/useChatInputState"; import type { components } from "@/lib/api/types"; +import ChatSettingsBadge from "./ChatSettingsBadge"; export default function ChatToolSelect({ toolInput, @@ -88,11 +88,10 @@ export default function ChatToolSelect({ trigger={ } > diff --git a/web/src/components/chat/settings/index.ts b/web/src/components/chat/settings/index.ts index c409721..d10b785 100644 --- a/web/src/components/chat/settings/index.ts +++ b/web/src/components/chat/settings/index.ts @@ -1,3 +1,4 @@ +import ChatFileSelect from "./ChatFileSelect"; import ChatModelSelect from "./ChatModelSelect"; import ChatMoreSettings from "./ChatMoreSettings"; import ChatProviderSelect from "./ChatProviderSelect"; @@ -5,6 +6,7 @@ import ChatToolSelect from "./ChatToolSelect"; export { ChatModelSelect, + ChatFileSelect, ChatProviderSelect, ChatToolSelect, ChatMoreSettings, diff --git a/web/src/hooks/useChatInputState.tsx b/web/src/hooks/useChatInputState.tsx index 99031d7..fdad653 100644 --- a/web/src/hooks/useChatInputState.tsx +++ b/web/src/hooks/useChatInputState.tsx @@ -44,7 +44,7 @@ export const useChatInputState = ({ const [toolInput, setToolInput] = useState< components["schemas"]["SendChatToolInput"] | null >(initialTools || DEFAULT_TOOL_INPUT); - const [files, setFiles] = useState([]); + const [files, setFiles] = useState([]); const [maxTokens, setMaxTokens] = useState( initialOptions?.max_tokens ?? DEFAULT_MAX_TOKENS, ); @@ -86,11 +86,11 @@ export const useChatInputState = ({ [providers], ); - const onAddFile = useCallback((fileId: string) => { - setFiles((prevFiles) => [...prevFiles, fileId]); + const onAddFile = useCallback((file: components["schemas"]["ChatRsFile"]) => { + setFiles((prev) => [...prev, file]); }, []); const onRemoveFile = useCallback((fileId: string) => { - setFiles((prevFiles) => prevFiles.filter((id) => id !== fileId)); + setFiles((prev) => prev.filter((f) => f.id !== fileId)); }, []); const onRemoveAllFiles = useCallback(() => { setFiles([]); From 549923d0628064ef80c6bfd191b0f5f6de149fcf Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Mon, 1 Sep 2025 05:50:37 -0400 Subject: [PATCH 10/26] web: tweak tool manager to look like other pages --- web/src/components/ToolsManager.tsx | 56 +++++++++++++++-------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/web/src/components/ToolsManager.tsx b/web/src/components/ToolsManager.tsx index 1d480c7..edf96b8 100644 --- a/web/src/components/ToolsManager.tsx +++ b/web/src/components/ToolsManager.tsx @@ -1,4 +1,5 @@ import { + Check, CloudCog, Code2, FolderCog, @@ -228,19 +229,13 @@ export function ToolsManager({ return (
-
-

Tools (Beta)

-

+

+

Tools (Beta)

+

Manage your system tools and external API integrations

- - - Create New Tool @@ -264,16 +259,19 @@ export function ToolsManager({ Code Runner
- {!codeRunnerTool && ( - - )} + ) : ( + + )} + {!codeRunnerTool ? "Enable" : "Enabled"} + @@ -296,16 +294,20 @@ export function ToolsManager({ System / Time
- {!systemInfoTool && ( - - )} + ) : ( + + )} + {!systemInfoTool ? "Enable" : "Enabled"} + From 874755873db7411ece9d128ee26fd0532515247e Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Mon, 1 Sep 2025 05:54:35 -0400 Subject: [PATCH 11/26] web: tweak data fetching --- web/src/components/ToolsManager.tsx | 1 - web/src/components/chat/settings/ChatFileSelect.tsx | 2 +- web/src/lib/api/storage.ts | 3 ++- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/web/src/components/ToolsManager.tsx b/web/src/components/ToolsManager.tsx index edf96b8..0d300ec 100644 --- a/web/src/components/ToolsManager.tsx +++ b/web/src/components/ToolsManager.tsx @@ -41,7 +41,6 @@ import { DialogDescription, DialogHeader, DialogTitle, - DialogTrigger, } from "@/components/ui/dialog"; import { Select, diff --git a/web/src/components/chat/settings/ChatFileSelect.tsx b/web/src/components/chat/settings/ChatFileSelect.tsx index 23366d3..cc717fc 100644 --- a/web/src/components/chat/settings/ChatFileSelect.tsx +++ b/web/src/components/chat/settings/ChatFileSelect.tsx @@ -50,7 +50,7 @@ export default function ChatFileSelect({ onRemoveAllFiles, }: FileSelectionDialogProps) { const [open, setOpen] = useState(false); - const { data: files, isLoading } = useSessionFiles(sessionId); + const { data: files, isLoading } = useSessionFiles(sessionId, open); const fileList = useMemo(() => files || [], [files]); diff --git a/web/src/lib/api/storage.ts b/web/src/lib/api/storage.ts index ab2f25e..52ff58e 100644 --- a/web/src/lib/api/storage.ts +++ b/web/src/lib/api/storage.ts @@ -34,8 +34,9 @@ export const useUploadFile = () => { }); }; -export const useSessionFiles = (sessionId: string) => +export const useSessionFiles = (sessionId: string, enabled?: boolean) => useQuery({ + enabled, queryKey: ["files", { sessionId }], queryFn: async () => { const res = await client.GET("/storage/{session_id}", { From aedf1c6c4369ab25b6efb217a8d9c73996319084 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Mon, 1 Sep 2025 05:58:26 -0400 Subject: [PATCH 12/26] Update chat-bubble.tsx --- web/src/components/ui/chat/chat-bubble.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/ui/chat/chat-bubble.tsx b/web/src/components/ui/chat/chat-bubble.tsx index 59cf680..18145cd 100644 --- a/web/src/components/ui/chat/chat-bubble.tsx +++ b/web/src/components/ui/chat/chat-bubble.tsx @@ -7,7 +7,7 @@ import { Button } from "../button"; // ChatBubble const chatBubbleVariant = cva( - "flex flex-col-reverse md:flex-row gap-2 px-0.5 max-w-[100%] md:max-w-[80%] relative group", + "flex flex-col-reverse md:flex-row gap-2 px-0.5 first:mt-2 max-w-[100%] md:max-w-[80%] relative group", { variants: { variant: { From 5d1592f96c2d0b0d69afebf8630beb152ca37747 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Mon, 1 Sep 2025 17:00:40 -0400 Subject: [PATCH 13/26] Update guard.rs --- server/src/auth/guard.rs | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/server/src/auth/guard.rs b/server/src/auth/guard.rs index 2f3a20c..ebd3487 100644 --- a/server/src/auth/guard.rs +++ b/server/src/auth/guard.rs @@ -40,14 +40,6 @@ impl<'r> FromRequest<'r> for ChatRsUserId { type Error = &'r str; async fn from_request(req: &'r rocket::Request<'_>) -> Outcome { - // Try authentication via proxy headers if configured - if let Some(config) = req.rocket().state::() { - if let Some(proxy_user) = get_sso_user_from_headers(config, req.headers()) { - let mut db = try_outcome!(req.guard::().await); - return get_sso_auth_outcome(&proxy_user, config, &mut db).await; - } - }; - // Try authentication via API key if let Some(auth_header) = req.headers().get_one("Authorization") { let encryptor = req.rocket().state::().expect("should exist"); @@ -57,10 +49,21 @@ impl<'r> FromRequest<'r> for ChatRsUserId { // Try authentication via session let session = try_outcome!(req.guard::>().await); - match session.tap(|data| data.and_then(|auth_session| auth_session.user_id())) { - Some(user_id) => Outcome::Success(ChatRsUserId(user_id)), - None => Outcome::Error((Status::Unauthorized, "Unauthorized")), + if let Some(user_id) = + session.tap(|data| data.and_then(|auth_session| auth_session.user_id())) + { + return Outcome::Success(ChatRsUserId(user_id)); } + + // Try authentication via proxy headers if configured + if let Some(config) = req.rocket().state::() { + if let Some(proxy_user) = get_sso_user_from_headers(config, req.headers()) { + let mut db = try_outcome!(req.guard::().await); + return get_sso_auth_outcome(&proxy_user, config, &mut db).await; + } + }; + + Outcome::Error((Status::Unauthorized, "Unauthorized")) } } From 0dd302c83993fd4e6450123b3de4a429ba58a0a2 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 00:12:19 -0400 Subject: [PATCH 14/26] server: support file input --- server/Cargo.lock | 1 + server/Cargo.toml | 1 + server/src/api/chat.rs | 65 ++++++----- server/src/db/models/chat.rs | 30 +++-- server/src/db/models/file.rs | 10 +- server/src/provider.rs | 116 +++++++++++++++++-- server/src/provider/anthropic.rs | 27 +---- server/src/provider/anthropic/request.rs | 139 ++++++++++++++--------- server/src/provider/lorem.rs | 10 +- server/src/provider/ollama.rs | 25 +--- server/src/provider/ollama/request.rs | 92 ++++++++------- server/src/provider/ollama/response.rs | 24 ++-- server/src/provider/openai.rs | 31 ++--- server/src/provider/openai/request.rs | 119 +++++++++++++------ server/src/storage.rs | 39 ++++++- server/src/storage/data_guard.rs | 2 +- server/src/storage/local.rs | 34 +++++- server/src/tools.rs | 44 +++---- server/src/tools/core.rs | 2 +- server/src/tools/system/files.rs | 2 +- 20 files changed, 516 insertions(+), 297 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index 860b6fd..1f3a8a8 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -373,6 +373,7 @@ version = "0.6.0" dependencies = [ "aes-gcm", "astral-tokio-tar", + "base64 0.22.1", "bollard", "chrono", "const_format", diff --git a/server/Cargo.toml b/server/Cargo.toml index 3185807..35a060c 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -18,6 +18,7 @@ strip = true [dependencies] aes-gcm = "0.10.3" astral-tokio-tar = "0.5.2" +base64 = "0.22.1" bollard = { version = "0.19.1", features = ["ssl"] } chrono = { version = "0.4.41", features = ["serde"] } const_format = "0.2.34" diff --git a/server/src/api/chat.rs b/server/src/api/chat.rs index 580ddd2..aa1e94a 100644 --- a/server/src/api/chat.rs +++ b/server/src/api/chat.rs @@ -23,10 +23,11 @@ use crate::{ DbConnection, DbPool, }, errors::ApiError, - provider::{build_llm_provider_api, LlmError, LlmProviderOptions}, + provider::{build_llm_messages, build_llm_provider_api, LlmError, LlmProviderOptions}, redis::{ExclusiveRedisClient, RedisClient}, + storage::LocalStorage, stream::*, - tools::{get_llm_tools_from_input, SendChatToolInput}, + tools::SendChatToolInput, utils::{generate_title, Encryptor}, }; @@ -66,6 +67,8 @@ pub struct SendChatInput<'a> { options: LlmProviderOptions, /// Configuration of tools available to the assistant tools: Option, + /// IDs of the file(s) to attach to this message + files: Option>, } #[derive(JsonSchema, serde::Serialize)] @@ -86,6 +89,7 @@ pub async fn send_chat_stream( redis: RedisClient, redis_writer: ExclusiveRedisClient, encryptor: &State, + storage: &State, http_client: &State, session_id: Uuid, mut input: Json>, @@ -117,12 +121,26 @@ pub async fn send_chat_stream( // Get the user's chosen tools let mut tools = None; - if let Some(tool_input) = input.tools.as_ref() { - let mut tool_db_service = ToolDbService::new(&mut db); - tools = Some(get_llm_tools_from_input(&user_id, tool_input, &mut tool_db_service).await?); + if let Some(conf) = input.tools.take() { + let llm_tools = conf + .get_llm_tools(&user_id, &mut ToolDbService::new(&mut db)) + .await?; + tools = Some(llm_tools); + + // Update session metadata with new tools if needed + if session.meta.tool_config.as_ref().is_none_or(|c| *c != conf) { + let data = UpdateChatRsSession { + meta: Some(ChatRsSessionMeta::new(Some(conf))), + ..Default::default() + }; + ChatDbService::new(&mut db) + .update_session(&user_id, &session_id, data) + .await?; + } } // Generate session title if needed, and save user message to database + let attached_file_ids = input.files.take(); if let Some(user_message) = &input.message { if messages.is_empty() && session.title == DEFAULT_SESSION_TITLE { generate_title( @@ -134,47 +152,34 @@ pub async fn send_chat_stream( db_pool, ); } - let new_message = ChatDbService::new(&mut db) + let message_meta = attached_file_ids + .map(|ids| ChatRsMessageMeta::new_user(UserMeta { files: Some(ids) })) + .unwrap_or_default(); + let message = ChatDbService::new(&mut db) .save_message(NewChatRsMessage { content: user_message, session_id: &session_id, role: ChatRsMessageRole::User, - meta: ChatRsMessageMeta::default(), + meta: message_meta, }) .await?; - messages.push(new_message); - } - - // Update session metadata if needed - if let Some(tool_input) = input.tools.take() { - if session - .meta - .tool_config - .is_none_or(|config| config != tool_input) - { - let meta = ChatRsSessionMeta::new(Some(tool_input)); - let data = UpdateChatRsSession { - meta: Some(&meta), - ..Default::default() - }; - ChatDbService::new(&mut db) - .update_session(&user_id, &session_id, data) - .await?; - } + messages.push(message); } - // Get the provider's stream response + // Convert the messages, and get the provider's response stream + let llm_messages = + build_llm_messages(messages, &user_id, &session_id, &mut db, &storage).await?; let stream = provider_api - .chat_stream(messages, tools, &input.options) + .chat_stream(llm_messages, tools, &input.options) .await?; - let provider_id = input.provider_id; - let provider_options = input.options.clone(); // Create the Redis stream let mut stream_writer = LlmStreamWriter::new(redis_writer, &user_id, &session_id); stream_writer.start().await?; // Spawn a task to stream and save the response + let provider_id = input.provider_id.clone(); + let provider_options = input.options.clone(); tokio::spawn(async move { let (text, tool_calls, usage, errors, cancelled) = stream_writer.process(stream).await; let assistant_meta = AssistantMeta { diff --git a/server/src/db/models/chat.rs b/server/src/db/models/chat.rs index 21c8cd7..1b27574 100644 --- a/server/src/db/models/chat.rs +++ b/server/src/db/models/chat.rs @@ -11,7 +11,7 @@ use crate::{ tools::SendChatToolInput, }; -#[derive(Identifiable, Associations, Queryable, Selectable, JsonSchema, serde::Serialize)] +#[derive(Identifiable, Associations, Queryable, Selectable, JsonSchema, Serialize)] #[diesel(belongs_to(ChatRsUser, foreign_key = user_id))] #[diesel(table_name = super::schema::chat_sessions)] pub struct ChatRsSession { @@ -47,12 +47,12 @@ pub struct NewChatRsSession<'r> { #[diesel(table_name = super::schema::chat_sessions)] pub struct UpdateChatRsSession<'r> { pub title: Option<&'r str>, - pub meta: Option<&'r ChatRsSessionMeta>, + pub meta: Option, } #[derive(diesel_derive_enum::DbEnum)] #[db_enum(existing_type_path = "crate::db::schema::sql_types::ChatMessageRole")] -#[derive(Debug, PartialEq, Eq, JsonSchema, serde::Serialize)] +#[derive(Debug, PartialEq, Eq, JsonSchema, Serialize)] pub enum ChatRsMessageRole { User, Assistant, @@ -60,7 +60,7 @@ pub enum ChatRsMessageRole { Tool, } -#[derive(Identifiable, Queryable, Selectable, Associations, JsonSchema, serde::Serialize)] +#[derive(Identifiable, Queryable, Selectable, Associations, JsonSchema, Serialize)] #[diesel(belongs_to(ChatRsSession, foreign_key = session_id))] #[diesel(table_name = super::schema::chat_messages)] pub struct ChatRsMessage { @@ -74,6 +74,9 @@ pub struct ChatRsMessage { #[derive(Debug, Default, JsonSchema, Serialize, Deserialize, AsJsonb)] pub struct ChatRsMessageMeta { + /// User messages: metadata associated with the user message + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, /// Assistant messages: metadata associated with the assistant message #[serde(skip_serializing_if = "Option::is_none")] pub assistant: Option, @@ -82,14 +85,27 @@ pub struct ChatRsMessageMeta { pub tool_call: Option, } impl ChatRsMessageMeta { - pub fn new_assistant(assistant: AssistantMeta) -> Self { + pub fn new_assistant(assistant_meta: AssistantMeta) -> Self { + Self { + assistant: Some(assistant_meta), + ..Default::default() + } + } + pub fn new_user(user_meta: UserMeta) -> Self { Self { - assistant: Some(assistant), - tool_call: None, + user: Some(user_meta), + ..Default::default() } } } +#[derive(Debug, Default, JsonSchema, Serialize, Deserialize)] +pub struct UserMeta { + /// The IDs of the files attached to this message + #[serde(skip_serializing_if = "Option::is_none")] + pub files: Option>, +} + #[derive(Debug, Default, JsonSchema, Serialize, Deserialize)] pub struct AssistantMeta { /// The ID of the LLM provider used to generate this message diff --git a/server/src/db/models/file.rs b/server/src/db/models/file.rs index 4ef6b7a..72613be 100644 --- a/server/src/db/models/file.rs +++ b/server/src/db/models/file.rs @@ -3,7 +3,7 @@ use diesel::prelude::*; use schemars::JsonSchema; use uuid::Uuid; -use crate::db::models::ChatRsUser; +use crate::{db::models::ChatRsUser, provider::LlmError}; #[derive(Identifiable, Associations, Queryable, Selectable, JsonSchema, serde::Serialize)] #[diesel(belongs_to(ChatRsUser, foreign_key = user_id))] @@ -44,15 +44,15 @@ pub enum ChatRsFileType { Pdf, } -impl TryFrom<&'static str> for ChatRsFileType { - type Error = &'static str; +impl TryFrom<&str> for ChatRsFileType { + type Error = LlmError; - fn try_from(file_type: &'static str) -> Result { + fn try_from(file_type: &str) -> Result { match file_type { "text" => Ok(ChatRsFileType::Text), "image" => Ok(ChatRsFileType::Image), "pdf" => Ok(ChatRsFileType::Pdf), - _ => Err("Invalid file type"), + _ => Err(LlmError::InvalidFileType(file_type.to_owned())), } } } diff --git a/server/src/provider.rs b/server/src/provider.rs index b0b5818..5df2b2b 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -1,10 +1,10 @@ //! LLM providers API -pub mod anthropic; +mod anthropic; pub mod lorem; pub mod models; -pub mod ollama; -pub mod openai; +mod ollama; +mod openai; mod utils; use std::pin::Pin; @@ -15,11 +15,19 @@ use schemars::JsonSchema; use uuid::Uuid; use crate::{ - db::models::{ChatRsMessage, ChatRsProviderType, ChatRsToolCall}, + db::{ + models::{ + ChatRsFileType, ChatRsMessage, ChatRsMessageRole, ChatRsProviderType, ChatRsToolCall, + }, + services::FileDbService, + DbConnection, + }, + errors::ApiError, provider::{ anthropic::AnthropicProvider, lorem::LoremProvider, models::LlmModel, ollama::OllamaProvider, openai::OpenAIProvider, }, + storage::LocalStorage, }; pub const DEFAULT_MAX_TOKENS: u32 = 2000; @@ -52,6 +60,10 @@ pub enum LlmError { DecryptionError, #[error("Redis error: {0}")] Redis(#[from] fred::error::Error), + #[error("File error: {0}")] + Io(#[from] std::io::Error), + #[error("Invalid file type: {0}")] + InvalidFileType(String), } /// LLM errors during streaming @@ -71,10 +83,10 @@ pub enum LlmStreamError { Redis(#[from] fred::error::Error), } -/// Shared stream response type for LLM providers +/// Stream response type for LLM providers pub type LlmStream = Pin + Send>>; -/// Shared stream chunk result type for LLM providers +/// Stream chunk result type for LLM providers pub type LlmStreamChunkResult = Result; /// A streaming chunk of data from the LLM provider @@ -101,7 +113,7 @@ pub struct LlmUsage { pub cost: Option, } -/// Shared configuration for LLM provider requests +/// Configuration for LLM provider requests #[derive(Clone, Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] pub struct LlmProviderOptions { pub model: String, @@ -109,6 +121,31 @@ pub struct LlmProviderOptions { pub max_tokens: Option, } +/// Generic message type to send to LLM providers +pub enum LlmMessage { + User(LlmUserMessage), + Assistant(LlmAssistantMessage), + System(String), + Tool(LlmToolResult), +} + +pub struct LlmUserMessage { + text: String, + files: Option>, +} + +pub struct LlmFileInput { + pub name: String, + pub file_type: ChatRsFileType, + pub content_type: String, + pub content: String, +} + +pub struct LlmAssistantMessage { + text: String, + tool_calls: Option>, +} + /// Generic tool that can be passed to LLM providers #[derive(Debug)] pub struct LlmTool { @@ -129,13 +166,19 @@ pub enum LlmToolType { ExternalApi, } +pub struct LlmToolResult { + tool_call_id: String, + tool_name: String, + content: String, +} + /// Unified API for LLM providers #[async_trait] pub trait LlmApiProvider: Send + Sync + DynClone { /// Stream a chat response from the provider async fn chat_stream( &self, - messages: Vec, + messages: Vec, tools: Option>, options: &LlmProviderOptions, ) -> Result; @@ -175,3 +218,60 @@ pub fn build_llm_provider_api( ChatRsProviderType::Lorem => Ok(Box::new(LoremProvider::new())), } } + +/// Convert database messages to the generic messages to send to the provider implementation +pub async fn build_llm_messages( + messages: Vec, + user_id: &Uuid, + session_id: &Uuid, + db: &mut DbConnection, + storage: &LocalStorage, +) -> Result, ApiError> { + let mut llm_messages = Vec::with_capacity(messages.len()); + + for message in messages { + match message.role { + ChatRsMessageRole::User => { + let mut files: Option> = None; + if let Some(file_ids) = message.meta.user.and_then(|u| u.files) { + let mut file_db_service = FileDbService::new(db); + for file_id in file_ids { + let file = file_db_service + .find_session_file(user_id, session_id, &file_id) + .await?; + let (file_type, content) = + file.read_to_string(Some(session_id), storage).await?; + files.get_or_insert_default().push(LlmFileInput { + name: file.path, + content_type: file.content_type, + file_type, + content, + }); + } + } + llm_messages.push(LlmMessage::User(LlmUserMessage { + text: message.content, + files, + })) + } + ChatRsMessageRole::Assistant => { + llm_messages.push(LlmMessage::Assistant(LlmAssistantMessage { + text: message.content, + tool_calls: message.meta.assistant.and_then(|a| a.tool_calls), + })) + } + ChatRsMessageRole::System => llm_messages.push(LlmMessage::System(message.content)), + ChatRsMessageRole::Tool => { + if let Some(tool_call) = message.meta.tool_call { + llm_messages.push(LlmMessage::Tool(LlmToolResult { + tool_call_id: tool_call.id, + tool_name: tool_call.tool_name, + content: message.content, + })) + } + } + } + } + + Ok(llm_messages) +} diff --git a/server/src/provider/anthropic.rs b/server/src/provider/anthropic.rs index 1feda75..5be59da 100644 --- a/server/src/provider/anthropic.rs +++ b/server/src/provider/anthropic.rs @@ -5,23 +5,8 @@ mod response; use rocket::{async_stream, async_trait, futures::StreamExt}; -use crate::{ - db::models::ChatRsMessage, - provider::{ - models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, - utils::get_sse_events, - LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmTool, LlmUsage, - DEFAULT_MAX_TOKENS, - }, -}; - -use { - request::{ - build_anthropic_messages, build_anthropic_tools, AnthropicContentBlock, AnthropicMessage, - AnthropicRequest, - }, - response::{parse_anthropic_event, AnthropicResponse, AnthropicResponseContentBlock}, -}; +use crate::provider::*; +use {request::*, response::*}; const MESSAGES_API_URL: &str = "https://api.anthropic.com/v1/messages"; const API_VERSION: &str = "2023-06-01"; @@ -52,7 +37,7 @@ impl AnthropicProvider { impl LlmApiProvider for AnthropicProvider { async fn chat_stream( &self, - messages: Vec, + messages: Vec, tools: Option>, options: &LlmProviderOptions, ) -> Result { @@ -88,7 +73,7 @@ impl LlmApiProvider for AnthropicProvider { } let stream = async_stream::stream! { - let mut sse_event_stream = get_sse_events(response); + let mut sse_event_stream = utils::get_sse_events(response); let mut tool_calls = Vec::new(); while let Some(event_result) = sse_event_stream.next().await { match event_result { @@ -164,9 +149,9 @@ impl LlmApiProvider for AnthropicProvider { } async fn list_models(&self) -> Result, LlmError> { - let models_service = ModelsDevService::new(&self.redis, &self.client); + let models_service = models::ModelsDevService::new(&self.redis, &self.client); let models = models_service - .list_models(ModelsDevServiceProvider::Anthropic) + .list_models(models::ModelsDevServiceProvider::Anthropic) .await?; Ok(models) diff --git a/server/src/provider/anthropic/request.rs b/server/src/provider/anthropic/request.rs index 43af62b..2208580 100644 --- a/server/src/provider/anthropic/request.rs +++ b/server/src/provider/anthropic/request.rs @@ -2,71 +2,91 @@ use std::collections::HashMap; use serde::Serialize; -use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole}, - provider::LlmTool, -}; +use crate::provider::*; pub fn build_anthropic_messages<'a>( - messages: &'a [ChatRsMessage], + messages: &'a [LlmMessage], ) -> (Vec>, Option<&'a str>) { - let system_prompt = messages - .iter() - .rfind(|message| message.role == ChatRsMessageRole::System) - .map(|message| message.content.as_str()); + let system_prompt = messages.iter().rev().find_map(|message| { + let LlmMessage::System(msg) = message else { + return None; + }; + Some(msg.as_str()) + }); let anthropic_messages: Vec = messages .iter() .filter_map(|message| { - let role = match message.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Tool => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => return None, - }; - let mut content_blocks = Vec::new(); - - // Handle tool result messages - if message.role == ChatRsMessageRole::Tool { - if let Some(executed_call) = &message.meta.tool_call { - content_blocks.push(AnthropicContentBlock::ToolResult { - tool_use_id: &executed_call.id, - content: &message.content, - }); - } - } else { - // Handle regular text content - if !message.content.is_empty() { - content_blocks.push(AnthropicContentBlock::Text { - text: &message.content, - }); + match message { + LlmMessage::User(user_message) => { + if !user_message.text.is_empty() { + content_blocks.push(AnthropicContentBlock::Text { + text: &user_message.text, + }); + } + if let Some(ref files) = user_message.files { + content_blocks.extend(files.iter().map(|file| match file.file_type { + ChatRsFileType::Text => AnthropicContentBlock::Document { + title: &file.name, + source: AnthropicSource::Text { + data: &file.content, + media_type: "text/plain", + }, + }, + ChatRsFileType::Image => AnthropicContentBlock::Image { + title: &file.name, + source: AnthropicSource::Base64 { + data: &file.content, + media_type: &file.content_type, + }, + }, + ChatRsFileType::Pdf => AnthropicContentBlock::Document { + title: &file.name, + source: AnthropicSource::Base64 { + data: &file.content, + media_type: "application/pdf", + }, + }, + })); + } + Some(AnthropicMessage { + role: "user", + content: content_blocks, + }) } - // Handle tool calls in assistant messages - if let Some(tool_calls) = message - .meta - .assistant - .as_ref() - .and_then(|a| a.tool_calls.as_ref()) - { - for tool_call in tool_calls { - content_blocks.push(AnthropicContentBlock::ToolUse { - id: &tool_call.id, - name: &tool_call.tool_name, - input: &tool_call.parameters, + LlmMessage::Assistant(assistant_message) => { + if !assistant_message.text.is_empty() { + content_blocks.push(AnthropicContentBlock::Text { + text: &assistant_message.text, }); } + if let Some(ref tool_calls) = assistant_message.tool_calls { + content_blocks.extend(tool_calls.iter().map(|tc| { + AnthropicContentBlock::ToolUse { + id: &tc.id, + name: &tc.tool_name, + input: &tc.parameters, + } + })); + } + Some(AnthropicMessage { + role: "assistant", + content: content_blocks, + }) } + LlmMessage::Tool(result) => { + content_blocks.push(AnthropicContentBlock::ToolResult { + tool_use_id: &result.tool_call_id, + content: &result.content, + }); + Some(AnthropicMessage { + role: "user", + content: content_blocks, + }) + } + _ => None, } - - if content_blocks.is_empty() { - return None; - } - - Some(AnthropicMessage { - role, - content: content_blocks, - }) }) .collect(); @@ -122,6 +142,14 @@ pub enum AnthropicContentBlock<'a> { Text { text: &'a str, }, + Image { + title: &'a str, + source: AnthropicSource<'a>, + }, + Document { + title: &'a str, + source: AnthropicSource<'a>, + }, ToolUse { id: &'a str, name: &'a str, @@ -132,3 +160,10 @@ pub enum AnthropicContentBlock<'a> { content: &'a str, }, } + +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum AnthropicSource<'a> { + Base64 { data: &'a str, media_type: &'a str }, + Text { data: &'a str, media_type: &'a str }, +} diff --git a/server/src/provider/lorem.rs b/server/src/provider/lorem.rs index 778ce9b..0e3f0f6 100644 --- a/server/src/provider/lorem.rs +++ b/server/src/provider/lorem.rs @@ -7,13 +7,7 @@ use rocket::futures::Stream; use rocket_okapi::JsonSchema; use tokio::time::{interval, Interval}; -use crate::{ - db::models::ChatRsMessage, - provider::{ - models::LlmModel, LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, - LlmStreamChunkResult, LlmStreamError, LlmTool, - }, -}; +use crate::provider::*; /// A test/dummy provider that streams 'lorem ipsum...' and emits test errors during the stream #[derive(Debug, Clone)] @@ -71,7 +65,7 @@ impl Stream for LoremStream { impl LlmApiProvider for LoremProvider { async fn chat_stream( &self, - _messages: Vec, + _messages: Vec, _tools: Option>, _options: &LlmProviderOptions, ) -> Result { diff --git a/server/src/provider/ollama.rs b/server/src/provider/ollama.rs index 76f8d7c..6d2e4a5 100644 --- a/server/src/provider/ollama.rs +++ b/server/src/provider/ollama.rs @@ -5,23 +5,8 @@ mod response; use rocket::{async_stream, async_trait, futures::StreamExt}; -use crate::{ - db::models::ChatRsMessage, - provider::{ - models::LlmModel, utils::get_json_events, LlmApiProvider, LlmError, LlmProviderOptions, - LlmStream, LlmStreamChunk, LlmTool, LlmUsage, - }, -}; - -use { - request::{ - build_ollama_messages, build_ollama_tools, OllamaChatRequest, OllamaCompletionRequest, - OllamaOptions, - }, - response::{ - parse_ollama_event, OllamaCompletionResponse, OllamaModelsResponse, OllamaToolCall, - }, -}; +use crate::provider::*; +use {request::*, response::*}; const CHAT_API_URL: &str = "/api/chat"; const COMPLETION_API_URL: &str = "/api/generate"; @@ -47,7 +32,7 @@ impl OllamaProvider { impl LlmApiProvider for OllamaProvider { async fn chat_stream( &self, - messages: Vec, + messages: Vec, tools: Option>, options: &LlmProviderOptions, ) -> Result { @@ -84,8 +69,8 @@ impl LlmApiProvider for OllamaProvider { } let stream = async_stream::stream! { - let mut json_stream = get_json_events(response); - let mut tool_calls: Vec = Vec::new(); + let mut json_stream = utils::get_json_events(response); + let mut tool_calls: Vec = Vec::new(); while let Some(event) = json_stream.next().await { match event { Ok(event) => { diff --git a/server/src/provider/ollama/request.rs b/server/src/provider/ollama/request.rs index 520b41f..c685b39 100644 --- a/server/src/provider/ollama/request.rs +++ b/server/src/provider/ollama/request.rs @@ -3,59 +3,63 @@ use serde::Serialize; use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole}, - provider::LlmTool, + db::models::ChatRsFileType, + provider::{LlmMessage, LlmTool}, tools::ToolParameters, }; -/// Convert ChatRsMessages to Ollama messages -pub fn build_ollama_messages(messages: &[ChatRsMessage]) -> Vec { +/// Convert LlmMessages to Ollama messages +pub fn build_ollama_messages(messages: &[LlmMessage]) -> Vec { messages .iter() - .map(|msg| { - let role = match msg.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => "system", - ChatRsMessageRole::Tool => "tool", - }; - - let mut ollama_msg = OllamaMessage { - role, - content: &msg.content, - tool_calls: None, - tool_name: None, - }; - - // Handle tool calls in assistant messages - if msg.role == ChatRsMessageRole::Assistant { - if let Some(msg_tool_calls) = msg - .meta - .assistant - .as_ref() - .and_then(|m| m.tool_calls.as_ref()) - { - let tool_calls = msg_tool_calls + .map(|message| match message { + LlmMessage::User(user_message) => { + let images = user_message.files.as_ref().map(|files| { + files + .iter() + .filter_map(|file| match file.file_type { + ChatRsFileType::Image => Some(file.content.as_str()), + _ => None, + }) + .collect::>() + }); + OllamaMessage { + role: "user", + content: &user_message.text, + images, + ..Default::default() + } + } + LlmMessage::Assistant(assistant_message) => { + let tool_calls = assistant_message.tool_calls.as_ref().map(|tool_calls| { + tool_calls .iter() .map(|tc| OllamaToolCall { - function: OllamaToolFunction { + function: OllamaFunction { name: &tc.tool_name, arguments: &tc.parameters, }, }) - .collect(); - ollama_msg.tool_calls = Some(tool_calls); + .collect() + }); + OllamaMessage { + role: "assistant", + content: &assistant_message.text, + tool_calls, + ..Default::default() } } - - // Handle tool messages (results from tool calls) - if msg.role == ChatRsMessageRole::Tool { - if let Some(ref tool_call) = msg.meta.tool_call { - ollama_msg.tool_name = Some(&tool_call.tool_name); - } - } - - ollama_msg + LlmMessage::System(text) => OllamaMessage { + role: "system", + content: text, + ..Default::default() + }, + LlmMessage::Tool(result) => OllamaMessage { + role: "tool", + content: &result.content, + tool_name: Some(&result.tool_name), + ..Default::default() + }, }) .collect() } @@ -100,11 +104,13 @@ pub struct OllamaCompletionRequest<'a> { } /// Ollama chat message -#[derive(Debug, Serialize)] +#[derive(Debug, Default, Serialize)] pub struct OllamaMessage<'a> { pub role: &'a str, pub content: &'a str, #[serde(skip_serializing_if = "Option::is_none")] + pub images: Option>, + #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_name: Option<&'a str>, @@ -113,12 +119,12 @@ pub struct OllamaMessage<'a> { /// Ollama tool call in a message #[derive(Debug, Serialize)] pub struct OllamaToolCall<'a> { - pub function: OllamaToolFunction<'a>, + pub function: OllamaFunction<'a>, } /// Ollama tool function #[derive(Debug, Serialize)] -pub struct OllamaToolFunction<'a> { +pub struct OllamaFunction<'a> { pub name: &'a str, pub arguments: &'a ToolParameters, } diff --git a/server/src/provider/ollama/response.rs b/server/src/provider/ollama/response.rs index 285b61e..b7c6610 100644 --- a/server/src/provider/ollama/response.rs +++ b/server/src/provider/ollama/response.rs @@ -9,8 +9,8 @@ use crate::{ /// Parse Ollama streaming event into LlmStreamChunks, and track tool calls pub fn parse_ollama_event( - event: OllamaStreamResponse, - tool_calls: &mut Vec, + event: OllamaStreamEvent, + tool_calls: &mut Vec, ) -> Vec> { let mut chunks = Vec::with_capacity(1); // Handle final message with usage stats @@ -42,10 +42,10 @@ pub fn parse_ollama_event( /// Ollama chat response (streaming) #[derive(Debug, Deserialize)] -pub struct OllamaStreamResponse { +pub struct OllamaStreamEvent { pub model: String, pub created_at: String, - pub message: OllamaMessage, + pub message: OllamaMessageResponse, pub done: bool, #[serde(default)] pub done_reason: Option, @@ -88,28 +88,28 @@ pub struct OllamaCompletionResponse { /// Ollama message in response #[derive(Debug, Deserialize)] -pub struct OllamaMessage { +pub struct OllamaMessageResponse { pub role: String, #[serde(default)] pub content: String, #[serde(default)] - pub tool_calls: Vec, + pub tool_calls: Vec, } /// Ollama tool call in response #[derive(Debug, Deserialize)] -pub struct OllamaToolCall { - pub function: OllamaToolFunction, +pub struct OllamaToolCallResponse { + pub function: OllamaFunctionResponse, } /// Ollama tool function in response #[derive(Debug, Deserialize)] -pub struct OllamaToolFunction { +pub struct OllamaFunctionResponse { pub name: String, pub arguments: serde_json::Value, } -impl OllamaToolFunction { +impl OllamaFunctionResponse { /// Convert to ChatRsToolCall if the tool exists in the provided tools pub fn convert(self, tools: &[LlmTool]) -> Option { let tool = tools.iter().find(|t| t.name == self.name)?; @@ -140,8 +140,8 @@ impl From<&OllamaCompletionResponse> for Option { } } -impl From<&OllamaStreamResponse> for Option { - fn from(response: &OllamaStreamResponse) -> Self { +impl From<&OllamaStreamEvent> for Option { + fn from(response: &OllamaStreamEvent) -> Self { if response.prompt_eval_count.is_some() || response.eval_count.is_some() { Some(LlmUsage { input_tokens: response.prompt_eval_count, diff --git a/server/src/provider/openai.rs b/server/src/provider/openai.rs index 8772f94..366cace 100644 --- a/server/src/provider/openai.rs +++ b/server/src/provider/openai.rs @@ -5,22 +5,8 @@ mod response; use rocket::{async_stream, async_trait, futures::StreamExt}; -use crate::{ - db::models::ChatRsMessage, - provider::{ - models::{LlmModel, ModelsDevService, ModelsDevServiceProvider}, - utils::get_sse_events, - LlmApiProvider, LlmError, LlmProviderOptions, LlmStream, LlmStreamChunk, LlmTool, LlmUsage, - }, -}; - -use { - request::{ - build_openai_messages, build_openai_tools, OpenAIMessage, OpenAIRequest, - OpenAIStreamOptions, - }, - response::{parse_openai_event, OpenAIResponse, OpenAIStreamToolCall}, -}; +use crate::provider::*; +use {request::*, response::*}; const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1"; const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1"; @@ -54,7 +40,7 @@ impl OpenAIProvider { impl LlmApiProvider for OpenAIProvider { async fn chat_stream( &self, - messages: Vec, + messages: Vec, tools: Option>, options: &LlmProviderOptions, ) -> Result { @@ -99,7 +85,7 @@ impl LlmApiProvider for OpenAIProvider { } let stream = async_stream::stream! { - let mut sse_event_stream = get_sse_events(response); + let mut sse_event_stream =utils:: get_sse_events(response); let mut tool_calls: Vec = Vec::new(); while let Some(event) = sse_event_stream.next().await { match event { @@ -134,7 +120,7 @@ impl LlmApiProvider for OpenAIProvider { model: &options.model, messages: vec![OpenAIMessage { role: "user", - content: Some(message), + content: Some(vec![OpenAIContent::Text { text: message }]), ..Default::default() }], max_tokens: options.max_tokens, @@ -182,12 +168,11 @@ impl LlmApiProvider for OpenAIProvider { } async fn list_models(&self) -> Result, LlmError> { - let models_service = ModelsDevService::new(&self.redis, &self.client); - let models = models_service + let models = models::ModelsDevService::new(&self.redis, &self.client) .list_models({ match self.base_url.as_str() { - OPENROUTER_API_BASE_URL => ModelsDevServiceProvider::OpenRouter, - _ => ModelsDevServiceProvider::OpenAI, + OPENROUTER_API_BASE_URL => models::ModelsDevServiceProvider::OpenRouter, + _ => models::ModelsDevServiceProvider::OpenAI, } }) .await?; diff --git a/server/src/provider/openai/request.rs b/server/src/provider/openai/request.rs index 9feafcf..5fcb774 100644 --- a/server/src/provider/openai/request.rs +++ b/server/src/provider/openai/request.rs @@ -1,45 +1,77 @@ use serde::Serialize; use crate::{ - db::models::{ChatRsMessage, ChatRsMessageRole}, - provider::LlmTool, + db::models::ChatRsFileType, + provider::{LlmMessage, LlmTool}, }; -pub fn build_openai_messages<'a>(messages: &'a [ChatRsMessage]) -> Vec> { +pub fn build_openai_messages<'a>(messages: &'a [LlmMessage]) -> Vec> { messages .iter() - .map(|message| { - let role = match message.role { - ChatRsMessageRole::User => "user", - ChatRsMessageRole::Assistant => "assistant", - ChatRsMessageRole::System => "system", - ChatRsMessageRole::Tool => "tool", - }; - let openai_message = OpenAIMessage { - role, - content: Some(&message.content), - tool_call_id: message.meta.tool_call.as_ref().map(|tc| tc.id.as_str()), - tool_calls: message - .meta - .assistant - .as_ref() - .and_then(|meta| meta.tool_calls.as_ref()) - .map(|tc| { - tc.iter() - .map(|tc| OpenAIToolCall { - id: &tc.id, - tool_type: "function", - function: OpenAIToolCallFunction { - name: &tc.tool_name, - arguments: serde_json::to_string(&tc.parameters) - .unwrap_or_default(), - }, - }) - .collect() + .map(|message| match message { + LlmMessage::User(user_message) => { + let mut content = Vec::new(); + if !user_message.text.is_empty() { + content.push(OpenAIContent::Text { + text: &user_message.text, + }); + } + if let Some(ref files) = user_message.files { + content.extend(files.iter().map(|file| match file.file_type { + ChatRsFileType::Text => OpenAIContent::Text { + text: &file.content, + }, + ChatRsFileType::Image => OpenAIContent::ImageUrl { url: &file.content }, + ChatRsFileType::Pdf => OpenAIContent::File { + file_data: &file.content, + filename: &file.name, + }, + })); + } + OpenAIMessage { + role: "user", + content: Some(content), + ..Default::default() + } + } + LlmMessage::Assistant(assistant_message) => { + let tool_calls = assistant_message.tool_calls.as_ref().map(|tc| { + tc.iter() + .map(|tc| OpenAIToolCall { + id: &tc.id, + tool_type: "function", + function: OpenAIToolCallFunction { + name: &tc.tool_name, + arguments: serde_json::to_string(&tc.parameters) + .unwrap_or_default(), + }, + }) + .collect() + }); + OpenAIMessage { + role: "assistant", + content: (!assistant_message.text.is_empty()).then(|| { + vec![OpenAIContent::Text { + text: &assistant_message.text, + }] }), - }; - - openai_message + tool_calls, + ..Default::default() + } + } + LlmMessage::System(text) => OpenAIMessage { + role: "system", + content: Some(vec![OpenAIContent::Text { text }]), + ..Default::default() + }, + LlmMessage::Tool(tool_result) => OpenAIMessage { + role: "tool", + content: Some(vec![OpenAIContent::Text { + text: &tool_result.content, + }]), + tool_call_id: Some(&tool_result.tool_call_id), + ..Default::default() + }, }) .collect() } @@ -90,13 +122,28 @@ pub struct OpenAIStreamOptions { pub struct OpenAIMessage<'a> { pub role: &'a str, #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option<&'a str>, + pub content: Option>>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option<&'a str>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>>, } +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename = "snakes_case")] +pub enum OpenAIContent<'a> { + Text { + text: &'a str, + }, + ImageUrl { + url: &'a str, + }, + File { + file_data: &'a str, + filename: &'a str, + }, +} + /// OpenAI tool definition #[derive(Debug, Serialize)] pub struct OpenAITool<'a> { @@ -110,8 +157,8 @@ pub struct OpenAITool<'a> { pub struct OpenAIToolFunction<'a> { name: &'a str, description: &'a str, - parameters: &'a serde_json::Value, strict: bool, + parameters: &'a serde_json::Value, } /// OpenAI tool call in messages diff --git a/server/src/storage.rs b/server/src/storage.rs index fb6b547..000cfd5 100644 --- a/server/src/storage.rs +++ b/server/src/storage.rs @@ -1,14 +1,18 @@ mod data_guard; mod local; -pub use data_guard::*; -pub use local::*; - -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use rocket::fairing::AdHoc; +use uuid::Uuid; -use crate::config::get_app_config; +use crate::{ + config::get_app_config, + db::models::{ChatRsFile, ChatRsFileType}, + provider::LlmError, +}; +pub use data_guard::*; +pub use local::*; /// Default data directory path. pub const DEFAULT_DATA_DIR: &str = "/data"; @@ -24,3 +28,28 @@ pub fn setup_storage() -> AdHoc { rocket.manage(storage) }) } + +impl ChatRsFile { + /// Get the file type and contents for LLM input. Uses base64 encoding for image and PDF files. + pub async fn read_to_string( + &self, + session_id: Option<&Uuid>, + storage: &LocalStorage, + ) -> Result<(ChatRsFileType, String), LlmError> { + let file_type = ChatRsFileType::try_from(self.file_type.as_str())?; + let content: String = match file_type { + ChatRsFileType::Text => { + let bytes = storage + .read_file_as_bytes(&self.user_id, session_id, Path::new(&self.path)) + .await?; + String::from_utf8_lossy(&bytes).into() + } + ChatRsFileType::Image | ChatRsFileType::Pdf => { + storage + .read_file_as_base64(&self.user_id, session_id, Path::new(&self.path)) + .await? + } + }; + Ok((file_type, content)) + } +} diff --git a/server/src/storage/data_guard.rs b/server/src/storage/data_guard.rs index 1d84bd1..25e4cce 100644 --- a/server/src/storage/data_guard.rs +++ b/server/src/storage/data_guard.rs @@ -46,7 +46,7 @@ impl<'r> FromData<'r> for FileData<'r> { if content_type.is_jpeg() || content_type.is_png() || content_type.is_webp() - || content_type.is_bmp() + || content_type.is_gif() { ChatRsFileType::Image } else if content_type.is_pdf() { diff --git a/server/src/storage/local.rs b/server/src/storage/local.rs index 1ac74f8..e99bd80 100644 --- a/server/src/storage/local.rs +++ b/server/src/storage/local.rs @@ -17,7 +17,7 @@ impl LocalStorage { LocalStorage { base_path } } - pub async fn read_file( + pub async fn read_file_as_bytes( &self, user_id: &Uuid, session_id: Option<&Uuid>, @@ -26,14 +26,23 @@ impl LocalStorage { let path = self.get_file_path(user_id, session_id, path)?; let mut file = File::open(path).await?; let metadata = file.metadata().await?; + let mut file_reader = BufReader::new(&mut file); let mut buffer = Vec::with_capacity(metadata.len() as usize); - let mut file_reader = BufReader::new(&mut file); file_reader.read_to_end(&mut buffer).await?; - Ok(buffer) } + pub async fn read_file_as_base64( + &self, + user_id: &Uuid, + session_id: Option<&Uuid>, + path: &Path, + ) -> IoResult { + let path = self.get_file_path(user_id, session_id, path)?; + tokio::task::spawn_blocking(move || read_base64(&path)).await? + } + pub async fn create_file( &self, user_id: &Uuid, @@ -103,3 +112,22 @@ impl LocalStorage { Ok(self.get_user_directory(user_id, session_id).join(path)) } } + +/// Reads a file as a base64 encoded string (synchronous because `base64` crate writer is synchronous). +fn read_base64(path: &Path) -> IoResult { + let mut file = std::fs::File::open(path)?; + let file_size = file.metadata()?.len(); + let estimated_size = (file_size + 2) / 3 * 4; + let mut file_reader = std::io::BufReader::new(&mut file); + + let mut result = Vec::with_capacity(estimated_size as usize); + { + let mut encoder = base64::write::EncoderWriter::new( + &mut result, + &base64::engine::general_purpose::STANDARD, + ); + std::io::copy(&mut file_reader, &mut encoder)?; + encoder.finish()?; + } + Ok(String::from_utf8(result).expect("base64 is valid UTF8")) +} diff --git a/server/src/tools.rs b/server/src/tools.rs index 7d81cab..598e15e 100644 --- a/server/src/tools.rs +++ b/server/src/tools.rs @@ -22,27 +22,29 @@ pub struct SendChatToolInput { pub external_apis: Option>, } -/// Get all tools from the user's input in LLM generic format -pub async fn get_llm_tools_from_input( - user_id: &Uuid, - input: &SendChatToolInput, - tool_db_service: &mut ToolDbService<'_>, -) -> Result, ApiError> { - let mut llm_tools = Vec::with_capacity(5); - if let Some(ref system_tool_input) = input.system { - let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; - let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; - llm_tools.extend(system_llm_tools); - } - if let Some(ref external_apis_input) = input.external_apis { - let external_api_tools = tool_db_service - .find_external_api_tools_by_user(&user_id) - .await?; - for tool_input in external_apis_input { - let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; - llm_tools.extend(api_llm_tools); +impl SendChatToolInput { + /// Get all tools from the user's input in LLM generic format + pub async fn get_llm_tools( + &self, + user_id: &Uuid, + tool_db_service: &mut ToolDbService<'_>, + ) -> Result, ApiError> { + let mut llm_tools = Vec::with_capacity(5); + if let Some(ref system_tool_input) = self.system { + let system_tools = tool_db_service.find_system_tools_by_user(&user_id).await?; + let system_llm_tools = system_tool_input.get_llm_tools(&system_tools)?; + llm_tools.extend(system_llm_tools); + } + if let Some(ref external_apis_input) = self.external_apis { + let external_api_tools = tool_db_service + .find_external_api_tools_by_user(&user_id) + .await?; + for tool_input in external_apis_input { + let api_llm_tools = tool_input.into_llm_tools(&external_api_tools)?; + llm_tools.extend(api_llm_tools); + } } - } - Ok(llm_tools) + Ok(llm_tools) + } } diff --git a/server/src/tools/core.rs b/server/src/tools/core.rs index e4a60d9..8cacf92 100644 --- a/server/src/tools/core.rs +++ b/server/src/tools/core.rs @@ -62,7 +62,7 @@ pub enum ToolError { ToolExecutionError(String), #[error("Tool execution cancelled: {0}")] Cancelled(String), - #[error("IO error: {0}")] + #[error("File error: {0}")] Io(#[from] std::io::Error), #[error("Database error: {0}")] Database(#[from] diesel::result::Error), diff --git a/server/src/tools/system/files.rs b/server/src/tools/system/files.rs index b587ba2..9d38961 100644 --- a/server/src/tools/system/files.rs +++ b/server/src/tools/system/files.rs @@ -189,7 +189,7 @@ impl SystemTool for Files<'_> { let input: ReadFileInput = serde_json::from_value(parameters)?; let path = Path::new(&input.path); let content_bytes = storage - .read_file(self.user_id, Some(self.session_id), path) + .read_file_as_bytes(self.user_id, Some(self.session_id), path) .await?; let content = String::from_utf8_lossy(&content_bytes); Ok((content.into(), ToolResponseFormat::Text)) From 7b1237704aa854472460e51f3515b604cc566fd6 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 00:37:51 -0400 Subject: [PATCH 15/26] server: organize providers module --- server/src/provider.rs | 185 +----------------- server/src/provider/core.rs | 171 ++++++++++++++++ server/src/provider/models.rs | 2 + server/src/provider/providers.rs | 11 ++ .../src/provider/{ => providers}/anthropic.rs | 0 .../{ => providers}/anthropic/request.rs | 0 .../{ => providers}/anthropic/response.rs | 0 server/src/provider/{ => providers}/lorem.rs | 12 +- server/src/provider/{ => providers}/ollama.rs | 0 .../{ => providers}/ollama/request.rs | 0 .../{ => providers}/ollama/response.rs | 0 server/src/provider/{ => providers}/openai.rs | 0 .../{ => providers}/openai/request.rs | 0 .../{ => providers}/openai/response.rs | 0 server/src/provider/utils.rs | 2 + server/src/storage/local.rs | 3 +- server/src/stream/llm_writer.rs | 2 +- 17 files changed, 199 insertions(+), 189 deletions(-) create mode 100644 server/src/provider/core.rs create mode 100644 server/src/provider/providers.rs rename server/src/provider/{ => providers}/anthropic.rs (100%) rename server/src/provider/{ => providers}/anthropic/request.rs (100%) rename server/src/provider/{ => providers}/anthropic/response.rs (100%) rename server/src/provider/{ => providers}/lorem.rs (91%) rename server/src/provider/{ => providers}/ollama.rs (100%) rename server/src/provider/{ => providers}/ollama/request.rs (100%) rename server/src/provider/{ => providers}/ollama/response.rs (100%) rename server/src/provider/{ => providers}/openai.rs (100%) rename server/src/provider/{ => providers}/openai/request.rs (100%) rename server/src/provider/{ => providers}/openai/response.rs (100%) diff --git a/server/src/provider.rs b/server/src/provider.rs index 5df2b2b..14e2e77 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -1,196 +1,27 @@ -//! LLM providers API +//! LLM providers module -mod anthropic; -pub mod lorem; +use uuid::Uuid; + +mod core; +pub use core::*; pub mod models; -mod ollama; -mod openai; +pub mod providers; mod utils; -use std::pin::Pin; - -use dyn_clone::DynClone; -use rocket::{async_trait, futures::Stream}; -use schemars::JsonSchema; -use uuid::Uuid; - use crate::{ db::{ - models::{ - ChatRsFileType, ChatRsMessage, ChatRsMessageRole, ChatRsProviderType, ChatRsToolCall, - }, + models::{ChatRsFileType, ChatRsMessage, ChatRsMessageRole, ChatRsProviderType}, services::FileDbService, DbConnection, }, errors::ApiError, - provider::{ - anthropic::AnthropicProvider, lorem::LoremProvider, models::LlmModel, - ollama::OllamaProvider, openai::OpenAIProvider, - }, + provider::{models::LlmModel, providers::*}, storage::LocalStorage, }; pub const DEFAULT_MAX_TOKENS: u32 = 2000; pub const DEFAULT_TEMPERATURE: f32 = 0.7; -/// LLM provider-related errors -#[derive(Debug, thiserror::Error)] -pub enum LlmError { - #[error("Missing API key")] - MissingApiKey, - #[error("Provider error: {0}")] - ProviderError(String), - #[error("models.dev error: {0}")] - ModelsDevError(String), - #[error("No chat response")] - NoResponse, - #[error("Unsupported provider")] - UnsupportedProvider, - #[error("Already streaming a response for this session")] - AlreadyStreaming, - #[error("No stream found, or the stream was cancelled")] - StreamNotFound, - #[error("Missing event in stream")] - NoStreamEvent, - #[error("Client disconnected")] - ClientDisconnected, - #[error("Encryption error")] - EncryptionError, - #[error("Decryption error")] - DecryptionError, - #[error("Redis error: {0}")] - Redis(#[from] fred::error::Error), - #[error("File error: {0}")] - Io(#[from] std::io::Error), - #[error("Invalid file type: {0}")] - InvalidFileType(String), -} - -/// LLM errors during streaming -#[derive(Debug, thiserror::Error)] -pub enum LlmStreamError { - #[error("Provider error: {0}")] - ProviderError(String), - #[error("Failed to parse event: {0}")] - Parsing(#[from] serde_json::Error), - #[error("Failed to decode response: {0}")] - Decoding(#[from] tokio_util::codec::LinesCodecError), - #[error("Timeout waiting for provider response")] - StreamTimeout, - #[error("Stream was cancelled")] - StreamCancelled, - #[error("Redis error: {0}")] - Redis(#[from] fred::error::Error), -} - -/// Stream response type for LLM providers -pub type LlmStream = Pin + Send>>; - -/// Stream chunk result type for LLM providers -pub type LlmStreamChunkResult = Result; - -/// A streaming chunk of data from the LLM provider -pub enum LlmStreamChunk { - Text(String), - ToolCalls(Vec), - PendingToolCall(LlmPendingToolCall), - Usage(LlmUsage), -} - -#[derive(Debug, Clone, serde::Serialize)] -pub struct LlmPendingToolCall { - pub index: usize, - pub tool_name: String, -} - -/// Usage stats from the LLM provider -#[derive(Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] -pub struct LlmUsage { - pub input_tokens: Option, - pub output_tokens: Option, - /// Only included by OpenRouter - #[serde(skip_serializing_if = "Option::is_none")] - pub cost: Option, -} - -/// Configuration for LLM provider requests -#[derive(Clone, Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] -pub struct LlmProviderOptions { - pub model: String, - pub temperature: Option, - pub max_tokens: Option, -} - -/// Generic message type to send to LLM providers -pub enum LlmMessage { - User(LlmUserMessage), - Assistant(LlmAssistantMessage), - System(String), - Tool(LlmToolResult), -} - -pub struct LlmUserMessage { - text: String, - files: Option>, -} - -pub struct LlmFileInput { - pub name: String, - pub file_type: ChatRsFileType, - pub content_type: String, - pub content: String, -} - -pub struct LlmAssistantMessage { - text: String, - tool_calls: Option>, -} - -/// Generic tool that can be passed to LLM providers -#[derive(Debug)] -pub struct LlmTool { - pub name: String, - pub description: String, - pub input_schema: serde_json::Value, - /// ID of the RsChat tool that this is derived from - pub tool_id: Uuid, - /// The type of tool this is derived from (internal, external API, etc.) - pub tool_type: LlmToolType, -} - -#[derive(Default, Debug, Clone, Copy, JsonSchema, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum LlmToolType { - #[default] - System, - ExternalApi, -} - -pub struct LlmToolResult { - tool_call_id: String, - tool_name: String, - content: String, -} - -/// Unified API for LLM providers -#[async_trait] -pub trait LlmApiProvider: Send + Sync + DynClone { - /// Stream a chat response from the provider - async fn chat_stream( - &self, - messages: Vec, - tools: Option>, - options: &LlmProviderOptions, - ) -> Result; - - /// Submit a prompt to the provider (not streamed) - async fn prompt(&self, message: &str, options: &LlmProviderOptions) - -> Result; - - /// List available models from the provider - async fn list_models(&self) -> Result, LlmError>; -} - /// Build the LLM API to make calls to the provider pub fn build_llm_provider_api( provider_type: &ChatRsProviderType, diff --git a/server/src/provider/core.rs b/server/src/provider/core.rs new file mode 100644 index 0000000..83061cd --- /dev/null +++ b/server/src/provider/core.rs @@ -0,0 +1,171 @@ +//! LLM providers - core structs and types + +use std::pin::Pin; + +use dyn_clone::DynClone; +use rocket::{async_trait, futures::Stream}; +use schemars::JsonSchema; +use uuid::Uuid; + +use crate::{ + db::models::{ChatRsFileType, ChatRsToolCall}, + provider::models::LlmModel, +}; + +/// Unified API for LLM providers +#[async_trait] +pub trait LlmApiProvider: Send + Sync + DynClone { + /// Stream a chat response from the provider + async fn chat_stream( + &self, + messages: Vec, + tools: Option>, + options: &LlmProviderOptions, + ) -> Result; + + /// Submit a prompt to the provider (not streamed) + async fn prompt(&self, message: &str, options: &LlmProviderOptions) + -> Result; + + /// List available models from the provider + async fn list_models(&self) -> Result, LlmError>; +} + +/// Stream response type for LLM providers +pub type LlmStream = Pin + Send>>; + +/// Stream chunk result type for LLM providers +pub type LlmStreamChunkResult = Result; + +/// A streaming chunk of data from the LLM provider +pub enum LlmStreamChunk { + Text(String), + ToolCalls(Vec), + PendingToolCall(LlmPendingToolCall), + Usage(LlmUsage), +} + +/// LLM provider-related errors +#[derive(Debug, thiserror::Error)] +pub enum LlmError { + #[error("Missing API key")] + MissingApiKey, + #[error("Provider error: {0}")] + ProviderError(String), + #[error("models.dev error: {0}")] + ModelsDevError(String), + #[error("No chat response")] + NoResponse, + #[error("Unsupported provider")] + UnsupportedProvider, + #[error("Already streaming a response for this session")] + AlreadyStreaming, + #[error("No stream found, or the stream was cancelled")] + StreamNotFound, + #[error("Missing event in stream")] + NoStreamEvent, + #[error("Client disconnected")] + ClientDisconnected, + #[error("Encryption error")] + EncryptionError, + #[error("Decryption error")] + DecryptionError, + #[error("Redis error: {0}")] + Redis(#[from] fred::error::Error), + #[error("File error: {0}")] + Io(#[from] std::io::Error), + #[error("Invalid file type: {0}")] + InvalidFileType(String), +} + +/// LLM errors that can occur during streaming +#[derive(Debug, thiserror::Error)] +pub enum LlmStreamError { + #[error("Provider error: {0}")] + ProviderError(String), + #[error("Failed to parse event: {0}")] + Parsing(#[from] serde_json::Error), + #[error("Failed to decode response: {0}")] + Decoding(#[from] tokio_util::codec::LinesCodecError), + #[error("Timeout waiting for provider response")] + StreamTimeout, + #[error("Stream was cancelled")] + StreamCancelled, + #[error("Redis error: {0}")] + Redis(#[from] fred::error::Error), +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct LlmPendingToolCall { + pub index: usize, + pub tool_name: String, +} + +/// Usage stats from the LLM provider +#[derive(Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] +pub struct LlmUsage { + pub input_tokens: Option, + pub output_tokens: Option, + /// Only included by OpenRouter + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option, +} + +/// Configuration for LLM provider requests +#[derive(Clone, Debug, Default, JsonSchema, serde::Serialize, serde::Deserialize)] +pub struct LlmProviderOptions { + pub model: String, + pub temperature: Option, + pub max_tokens: Option, +} + +/// Generic message type to send to LLM providers +pub enum LlmMessage { + User(LlmUserMessage), + Assistant(LlmAssistantMessage), + System(String), + Tool(LlmToolResult), +} + +pub struct LlmUserMessage { + pub text: String, + pub files: Option>, +} + +pub struct LlmFileInput { + pub name: String, + pub file_type: ChatRsFileType, + pub content_type: String, + pub content: String, +} + +pub struct LlmAssistantMessage { + pub text: String, + pub tool_calls: Option>, +} + +/// Generic tool that can be passed to LLM providers +#[derive(Debug)] +pub struct LlmTool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, + /// ID of the RsChat tool that this is derived from + pub tool_id: Uuid, + /// The type of tool this is derived from (internal, external API, etc.) + pub tool_type: LlmToolType, +} + +#[derive(Default, Debug, Clone, Copy, JsonSchema, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum LlmToolType { + #[default] + System, + ExternalApi, +} + +pub struct LlmToolResult { + pub tool_call_id: String, + pub tool_name: String, + pub content: String, +} diff --git a/server/src/provider/models.rs b/server/src/provider/models.rs index 6ed5bc8..f527ac1 100644 --- a/server/src/provider/models.rs +++ b/server/src/provider/models.rs @@ -1,3 +1,5 @@ +//! LLM model structs and utils + use std::collections::HashMap; use enum_iterator::{all, Sequence}; diff --git a/server/src/provider/providers.rs b/server/src/provider/providers.rs new file mode 100644 index 0000000..d558d49 --- /dev/null +++ b/server/src/provider/providers.rs @@ -0,0 +1,11 @@ +//! LLM provider implementations + +mod anthropic; +mod lorem; +mod ollama; +mod openai; + +pub use anthropic::*; +pub use lorem::*; +pub use ollama::*; +pub use openai::*; diff --git a/server/src/provider/anthropic.rs b/server/src/provider/providers/anthropic.rs similarity index 100% rename from server/src/provider/anthropic.rs rename to server/src/provider/providers/anthropic.rs diff --git a/server/src/provider/anthropic/request.rs b/server/src/provider/providers/anthropic/request.rs similarity index 100% rename from server/src/provider/anthropic/request.rs rename to server/src/provider/providers/anthropic/request.rs diff --git a/server/src/provider/anthropic/response.rs b/server/src/provider/providers/anthropic/response.rs similarity index 100% rename from server/src/provider/anthropic/response.rs rename to server/src/provider/providers/anthropic/response.rs diff --git a/server/src/provider/lorem.rs b/server/src/provider/providers/lorem.rs similarity index 91% rename from server/src/provider/lorem.rs rename to server/src/provider/providers/lorem.rs index 0e3f0f6..d93535a 100644 --- a/server/src/provider/lorem.rs +++ b/server/src/provider/providers/lorem.rs @@ -4,7 +4,6 @@ use std::pin::Pin; use std::time::Duration; use rocket::futures::Stream; -use rocket_okapi::JsonSchema; use tokio::time::{interval, Interval}; use crate::provider::*; @@ -12,19 +11,12 @@ use crate::provider::*; /// A test/dummy provider that streams 'lorem ipsum...' and emits test errors during the stream #[derive(Debug, Clone)] pub struct LoremProvider { - pub config: LoremConfig, -} - -#[derive(Debug, Clone, JsonSchema)] -pub struct LoremConfig { pub interval: u32, } impl LoremProvider { pub fn new() -> Self { - LoremProvider { - config: LoremConfig { interval: 400 }, - } + LoremProvider { interval: 400 } } } @@ -101,7 +93,7 @@ impl LlmApiProvider for LoremProvider { let stream: LlmStream = Box::pin(LoremStream { words: lorem_words, index: 0, - interval: interval(Duration::from_millis(self.config.interval.into())), + interval: interval(Duration::from_millis(self.interval.into())), }); tokio::time::sleep(Duration::from_millis(1000)).await; diff --git a/server/src/provider/ollama.rs b/server/src/provider/providers/ollama.rs similarity index 100% rename from server/src/provider/ollama.rs rename to server/src/provider/providers/ollama.rs diff --git a/server/src/provider/ollama/request.rs b/server/src/provider/providers/ollama/request.rs similarity index 100% rename from server/src/provider/ollama/request.rs rename to server/src/provider/providers/ollama/request.rs diff --git a/server/src/provider/ollama/response.rs b/server/src/provider/providers/ollama/response.rs similarity index 100% rename from server/src/provider/ollama/response.rs rename to server/src/provider/providers/ollama/response.rs diff --git a/server/src/provider/openai.rs b/server/src/provider/providers/openai.rs similarity index 100% rename from server/src/provider/openai.rs rename to server/src/provider/providers/openai.rs diff --git a/server/src/provider/openai/request.rs b/server/src/provider/providers/openai/request.rs similarity index 100% rename from server/src/provider/openai/request.rs rename to server/src/provider/providers/openai/request.rs diff --git a/server/src/provider/openai/response.rs b/server/src/provider/providers/openai/response.rs similarity index 100% rename from server/src/provider/openai/response.rs rename to server/src/provider/providers/openai/response.rs diff --git a/server/src/provider/utils.rs b/server/src/provider/utils.rs index 1845ab1..96c791f 100644 --- a/server/src/provider/utils.rs +++ b/server/src/provider/utils.rs @@ -1,3 +1,5 @@ +//! Utils for working with LLM responses + use rocket::futures::TryStreamExt; use serde::de::DeserializeOwned; use tokio_stream::{Stream, StreamExt}; diff --git a/server/src/storage/local.rs b/server/src/storage/local.rs index e99bd80..4c0cd9f 100644 --- a/server/src/storage/local.rs +++ b/server/src/storage/local.rs @@ -113,7 +113,8 @@ impl LocalStorage { } } -/// Reads a file as a base64 encoded string (synchronous because `base64` crate writer is synchronous). +/// Synchronously read a file as a base64 encoded string. +/// (This is synchronous because the `base64` crate is synchronous.) fn read_base64(path: &Path) -> IoResult { let mut file = std::fs::File::open(path)?; let file_size = file.metadata()?.len(); diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 0c3f499..7f2b99c 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -304,7 +304,7 @@ impl LlmStreamWriter { mod tests { use super::*; use crate::{ - provider::{lorem::LoremProvider, LlmApiProvider, LlmProviderOptions}, + provider::{providers::LoremProvider, LlmApiProvider, LlmProviderOptions}, redis::{ExclusiveClientManager, ExclusiveClientPool}, stream::{cancel_current_chat_stream, check_chat_stream_exists}, }; From 3318bff5317b9fb805f219922615d1671810ac3c Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 00:38:48 -0400 Subject: [PATCH 16/26] Update README.md --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index bb44dd1..570d6cd 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,7 @@ A fast, secure, self-hostable chat application built with Rust, TypeScript, and React. Chat with multiple AI providers using your own API keys, with real-time streaming built-in. -!! **Submission to the [T3 Chat Cloneathon](https://cloneathon.t3.chat/)** !! - -Demo link: https://rschat.fasharp.io (⚠️ This is a demo - don't expect your account/chats to be there when you come back. It may intermittently delete data. Please also don't enter any sensitive information or confidential data) +Demo link: https://rs-chat-demo.up.railway.app/ (⚠️ This is a demo - don't expect your account/chats to be there when you come back. It may intermittently delete all data. Please also don't enter any sensitive information or confidential data) ## ✨ Features From 0b1a76920bea5be671b4db80f41c38e1babb66bf Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 00:43:16 -0400 Subject: [PATCH 17/26] server: move import --- server/src/provider.rs | 2 +- server/src/provider/providers/anthropic/request.rs | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/server/src/provider.rs b/server/src/provider.rs index 14e2e77..9e3c520 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -10,7 +10,7 @@ mod utils; use crate::{ db::{ - models::{ChatRsFileType, ChatRsMessage, ChatRsMessageRole, ChatRsProviderType}, + models::{ChatRsMessage, ChatRsMessageRole, ChatRsProviderType}, services::FileDbService, DbConnection, }, diff --git a/server/src/provider/providers/anthropic/request.rs b/server/src/provider/providers/anthropic/request.rs index 2208580..06cc9f7 100644 --- a/server/src/provider/providers/anthropic/request.rs +++ b/server/src/provider/providers/anthropic/request.rs @@ -2,7 +2,10 @@ use std::collections::HashMap; use serde::Serialize; -use crate::provider::*; +use crate::{ + db::models::ChatRsFileType, + provider::{LlmMessage, LlmTool}, +}; pub fn build_anthropic_messages<'a>( messages: &'a [LlmMessage], From f6e4a061b633223a77301623beec81144bd6161a Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 00:48:52 -0400 Subject: [PATCH 18/26] server: tweak serialization --- server/src/tools.rs | 2 ++ server/src/tools/system.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/server/src/tools.rs b/server/src/tools.rs index 598e15e..32bd947 100644 --- a/server/src/tools.rs +++ b/server/src/tools.rs @@ -18,7 +18,9 @@ use { /// User configuration of tools when sending a chat message #[derive(Debug, Default, PartialEq, JsonSchema, serde::Serialize, serde::Deserialize)] pub struct SendChatToolInput { + #[serde(skip_serializing_if = "Option::is_none")] pub system: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub external_apis: Option>, } diff --git a/server/src/tools/system.rs b/server/src/tools/system.rs index f53f721..b197666 100644 --- a/server/src/tools/system.rs +++ b/server/src/tools/system.rs @@ -45,7 +45,7 @@ pub struct SystemToolInput { /// Enable/disable tools to get system information, current date/time, etc. #[serde(default)] info: bool, - #[serde(default)] + #[serde(default, skip_serializing_if = "Option::is_none")] files: Option, } impl SystemToolInput { From 06810cb30e7b2bfaf01c7987696b1859435af18d Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 01:06:14 -0400 Subject: [PATCH 19/26] server: fix openai file types --- server/src/provider/providers/openai/request.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/provider/providers/openai/request.rs b/server/src/provider/providers/openai/request.rs index 5fcb774..5df6445 100644 --- a/server/src/provider/providers/openai/request.rs +++ b/server/src/provider/providers/openai/request.rs @@ -130,7 +130,7 @@ pub struct OpenAIMessage<'a> { } #[derive(Debug, Serialize)] -#[serde(tag = "type", rename = "snakes_case")] +#[serde(tag = "type", rename_all = "snake_case")] pub enum OpenAIContent<'a> { Text { text: &'a str, From 2799d2f6f3a2242fdfaf25b48529c35982986e1b Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 01:56:01 -0400 Subject: [PATCH 20/26] server: fix base64 URLs --- .../src/provider/providers/openai/request.rs | 34 ++++++++++++------- server/src/storage.rs | 7 ++-- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/server/src/provider/providers/openai/request.rs b/server/src/provider/providers/openai/request.rs index 5df6445..51e9820 100644 --- a/server/src/provider/providers/openai/request.rs +++ b/server/src/provider/providers/openai/request.rs @@ -21,10 +21,14 @@ pub fn build_openai_messages<'a>(messages: &'a [LlmMessage]) -> Vec OpenAIContent::Text { text: &file.content, }, - ChatRsFileType::Image => OpenAIContent::ImageUrl { url: &file.content }, + ChatRsFileType::Image => OpenAIContent::ImageUrl { + image_url: OpenAIImageUrl { url: &file.content }, + }, ChatRsFileType::Pdf => OpenAIContent::File { - file_data: &file.content, - filename: &file.name, + file: OpenAIFile { + file_data: &file.content, + filename: &file.name, + }, }, })); } @@ -132,16 +136,20 @@ pub struct OpenAIMessage<'a> { #[derive(Debug, Serialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum OpenAIContent<'a> { - Text { - text: &'a str, - }, - ImageUrl { - url: &'a str, - }, - File { - file_data: &'a str, - filename: &'a str, - }, + Text { text: &'a str }, + ImageUrl { image_url: OpenAIImageUrl<'a> }, + File { file: OpenAIFile<'a> }, +} + +#[derive(Debug, Serialize)] +pub struct OpenAIImageUrl<'a> { + url: &'a str, +} + +#[derive(Debug, Serialize)] +pub struct OpenAIFile<'a> { + file_data: &'a str, + filename: &'a str, } /// OpenAI tool definition diff --git a/server/src/storage.rs b/server/src/storage.rs index 000cfd5..68ec40f 100644 --- a/server/src/storage.rs +++ b/server/src/storage.rs @@ -30,7 +30,7 @@ pub fn setup_storage() -> AdHoc { } impl ChatRsFile { - /// Get the file type and contents for LLM input. Uses base64 encoding for image and PDF files. + /// Get the file type and contents for LLM input. Uses base64 URLs for image and PDF files. pub async fn read_to_string( &self, session_id: Option<&Uuid>, @@ -45,9 +45,10 @@ impl ChatRsFile { String::from_utf8_lossy(&bytes).into() } ChatRsFileType::Image | ChatRsFileType::Pdf => { - storage + let b64_content = storage .read_file_as_base64(&self.user_id, session_id, Path::new(&self.path)) - .await? + .await?; + format!("data:{};base64,{}", self.content_type, b64_content) } }; Ok((file_type, content)) From 68a2151266c57aab1855eeef5f967e58bde7c47a Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 01:56:29 -0400 Subject: [PATCH 21/26] web: support attaching files to messages --- web/src/hooks/useChatInputState.tsx | 2 ++ web/src/lib/api/types.d.ts | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/web/src/hooks/useChatInputState.tsx b/web/src/hooks/useChatInputState.tsx index fdad653..cba1896 100644 --- a/web/src/hooks/useChatInputState.tsx +++ b/web/src/hooks/useChatInputState.tsx @@ -155,6 +155,7 @@ export const useChatInputState = ({ max_tokens: maxTokens, }, tools: toolInput, + files: files.length > 0 ? files.map((file) => file.id) : undefined, }); formRef.current?.reset(); }, [ @@ -162,6 +163,7 @@ export const useChatInputState = ({ selectedProvider, modelId, toolInput, + files, temperature, maxTokens, onSubmit, diff --git a/web/src/lib/api/types.d.ts b/web/src/lib/api/types.d.ts index 1f78ff3..8bf8a21 100644 --- a/web/src/lib/api/types.d.ts +++ b/web/src/lib/api/types.d.ts @@ -755,11 +755,17 @@ export interface components { /** @enum {string} */ ChatRsMessageRole: "User" | "Assistant" | "System" | "Tool"; ChatRsMessageMeta: { + /** @description User messages: metadata associated with the user message */ + user?: components["schemas"]["UserMeta"] | null; /** @description Assistant messages: metadata associated with the assistant message */ assistant?: components["schemas"]["AssistantMeta"] | null; /** @description Tool messages: metadata of the executed tool call */ tool_call?: components["schemas"]["ChatRsExecutedToolCall"] | null; }; + UserMeta: { + /** @description The IDs of the files attached to this message */ + files?: string[] | null; + }; AssistantMeta: { /** * Format: int32 @@ -777,7 +783,7 @@ export interface components { /** @description Whether this is a partial and/or interrupted message */ partial?: boolean | null; }; - /** @description Shared configuration for LLM provider requests */ + /** @description Configuration for LLM provider requests */ LlmProviderOptions: { model: string; /** Format: float */ @@ -891,6 +897,8 @@ export interface components { options: components["schemas"]["LlmProviderOptions"]; /** @description Configuration of tools available to the assistant */ tools?: components["schemas"]["SendChatToolInput"] | null; + /** @description IDs of the file(s) to attach to this message */ + files?: string[] | null; }; GetAllToolsResponse: { /** @description System tools */ From 155d7fa9d2f3f67b03389e9a3ab7d4490417dffe Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 02:15:50 -0400 Subject: [PATCH 22/26] web: show user's file attachments below message --- web/src/components/chat/ChatMessages.tsx | 3 ++ .../components/chat/messages/ChatMessage.tsx | 44 ++++++++++++++----- .../app/_appLayout/session/$sessionId.tsx | 3 ++ 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/web/src/components/chat/ChatMessages.tsx b/web/src/components/chat/ChatMessages.tsx index 63a41a3..aebfa20 100644 --- a/web/src/components/chat/ChatMessages.tsx +++ b/web/src/components/chat/ChatMessages.tsx @@ -9,6 +9,7 @@ interface Props { messages: Array; providers?: Array; tools?: components["schemas"]["GetAllToolsResponse"]; + files?: Array; onToolExecute: ( messageId: string, sessionId: string, @@ -24,6 +25,7 @@ export default memo(function ChatMessages({ messages, providers, tools, + files, onToolExecute, isStreaming, sessionId, @@ -56,6 +58,7 @@ export default memo(function ChatMessages({ user={user} providers={providers} tools={tools} + files={files} executedToolCalls={message.meta.assistant?.tool_calls?.filter((tc) => messages.some((m) => m.meta.tool_call?.id === tc.id), )} diff --git a/web/src/components/chat/messages/ChatMessage.tsx b/web/src/components/chat/messages/ChatMessage.tsx index 2eaca01..394702b 100644 --- a/web/src/components/chat/messages/ChatMessage.tsx +++ b/web/src/components/chat/messages/ChatMessage.tsx @@ -1,4 +1,4 @@ -import { Bot, Wrench } from "lucide-react"; +import { Bot, Paperclip, Wrench } from "lucide-react"; import React, { Suspense } from "react"; import Markdown from "react-markdown"; @@ -7,6 +7,7 @@ import { ChatBubbleAvatar, ChatBubbleMessage, } from "@/components/ui/chat/chat-bubble"; +import { API_URL } from "@/lib/api/client"; import type { components } from "@/lib/api/types"; import { cn } from "@/lib/utils"; import { CopyButton, DeleteButton, InfoButton } from "./ChatMessageActions"; @@ -24,6 +25,7 @@ interface Props { message: components["schemas"]["ChatRsMessage"]; user?: components["schemas"]["ChatRsUser"]; tools?: components["schemas"]["GetAllToolsResponse"]; + files?: components["schemas"]["ChatRsFile"][]; executedToolCalls?: components["schemas"]["ChatRsToolCall"][]; onExecuteToolCall: (messageId: string, toolCallId: string) => void; providers?: components["schemas"]["ChatRsProvider"][]; @@ -34,6 +36,7 @@ export default function ChatMessage({ message, user, tools, + files, executedToolCalls, onExecuteToolCall, providers, @@ -96,18 +99,35 @@ export default function ChatMessage({ )} {message.role === "User" && ( -
-
- - onDeleteMessage(message.id)} - variant="default" - /> -
-
- {formatDate(message.created_at)} + <> + {message.meta.user?.files?.map((fileId) => ( + <> + + + ))} +
+
+ + onDeleteMessage(message.id)} + variant="default" + /> +
+
+ {formatDate(message.created_at)} +
-
+ )} {message.role === "Tool" && (
diff --git a/web/src/routes/app/_appLayout/session/$sessionId.tsx b/web/src/routes/app/_appLayout/session/$sessionId.tsx index a290980..dfb1980 100644 --- a/web/src/routes/app/_appLayout/session/$sessionId.tsx +++ b/web/src/routes/app/_appLayout/session/$sessionId.tsx @@ -17,6 +17,7 @@ import { getChatSession, useGetChatSession, } from "@/lib/api/session"; +import { useSessionFiles } from "@/lib/api/storage"; import { useTools } from "@/lib/api/tool"; import type { components } from "@/lib/api/types"; import { useStreamingChats, useStreamingTools } from "@/lib/context"; @@ -48,6 +49,7 @@ function RouteComponent() { const { data: session } = useGetChatSession(sessionId); const { data: providers } = useProviders(); const { data: tools } = useTools(); + const { data: files } = useSessionFiles(sessionId); const { streamedChats, onUserSubmit } = useStreamingChats(); const { streamedTools, onToolExecute, onToolCancel } = useStreamingTools(); @@ -113,6 +115,7 @@ function RouteComponent() { messages={session?.messages || []} providers={providers} tools={tools} + files={files} sessionId={sessionId} onToolExecute={onToolExecute} isStreaming={currentStream?.status === "streaming"} From 99438228834b2d51c611d9f44323805d7fc6ae18 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 03:01:18 -0400 Subject: [PATCH 23/26] web: allow uploading through file dialog --- web/src/components/chat/ChatMessageInput.tsx | 10 +- .../components/chat/messages/ChatMessage.tsx | 24 +- .../chat/settings/ChatFileSelect.tsx | 267 ++++++++++++------ 3 files changed, 193 insertions(+), 108 deletions(-) diff --git a/web/src/components/chat/ChatMessageInput.tsx b/web/src/components/chat/ChatMessageInput.tsx index 9a84c9c..018665d 100644 --- a/web/src/components/chat/ChatMessageInput.tsx +++ b/web/src/components/chat/ChatMessageInput.tsx @@ -68,6 +68,8 @@ export default memo(function ChatMessageInput({ [providerId, onSelectModel], ); + const [isFileDialogOpen, setIsFileDialogOpen] = useState(false); + const [enterKeyShouldSubmit, setEnterKeyShouldSubmit] = useState(true); const onKeyDown = useCallback( (e: React.KeyboardEvent) => { @@ -132,13 +134,16 @@ export default memo(function ChatMessageInput({ const onDrop = useCallback( (acceptedFiles: File[]) => { - handleFileUpload(acceptedFiles); + if (!isFileDialogOpen) { + handleFileUpload(acceptedFiles); + } }, - [handleFileUpload], + [handleFileUpload, isFileDialogOpen], ); const { getRootProps, getInputProps, isDragActive } = useDropzone({ onDrop, + disabled: isFileDialogOpen, noClick: true, noKeyboard: true, }); @@ -205,6 +210,7 @@ export default memo(function ChatMessageInput({ onAddFile={onAddFile} onRemoveFile={onRemoveFile} onRemoveAllFiles={onRemoveAllFiles} + onOpenChange={setIsFileDialogOpen} /> {files.length > 0 && (
diff --git a/web/src/components/chat/messages/ChatMessage.tsx b/web/src/components/chat/messages/ChatMessage.tsx index 394702b..8162e0d 100644 --- a/web/src/components/chat/messages/ChatMessage.tsx +++ b/web/src/components/chat/messages/ChatMessage.tsx @@ -101,19 +101,17 @@ export default function ChatMessage({ {message.role === "User" && ( <> {message.meta.user?.files?.map((fileId) => ( - <> - - + ))}
diff --git a/web/src/components/chat/settings/ChatFileSelect.tsx b/web/src/components/chat/settings/ChatFileSelect.tsx index cc717fc..41090db 100644 --- a/web/src/components/chat/settings/ChatFileSelect.tsx +++ b/web/src/components/chat/settings/ChatFileSelect.tsx @@ -1,5 +1,6 @@ -import { Check, File, FileText, Image, Paperclip } from "lucide-react"; +import { Check, File, FileText, Image, Paperclip, Upload } from "lucide-react"; import { useCallback, useMemo, useState } from "react"; +import { useDropzone } from "react-dropzone"; import { Button } from "@/components/ui/button"; import { @@ -10,7 +11,7 @@ import { DialogTitle, DialogTrigger, } from "@/components/ui/dialog"; -import { useSessionFiles } from "@/lib/api/storage"; +import { useSessionFiles, useUploadFile } from "@/lib/api/storage"; import type { components } from "@/lib/api/types"; import { cn } from "@/lib/utils"; import ChatSettingsBadge from "./ChatSettingsBadge"; @@ -21,6 +22,7 @@ interface FileSelectionDialogProps { onAddFile: (file: components["schemas"]["ChatRsFile"]) => void; onRemoveFile: (fileId: string) => void; onRemoveAllFiles: () => void; + onOpenChange?: (open: boolean) => void; } function getFileIcon(fileType: components["schemas"]["ChatRsFileType"]) { @@ -48,9 +50,21 @@ export default function ChatFileSelect({ onAddFile, onRemoveFile, onRemoveAllFiles, + onOpenChange, }: FileSelectionDialogProps) { + const { data: files, isLoading } = useSessionFiles(sessionId); + const { mutate: uploadFile } = useUploadFile(); + const [open, setOpen] = useState(false); - const { data: files, isLoading } = useSessionFiles(sessionId, open); + const handleOpenChange = useCallback( + (open: boolean) => { + setOpen(open); + onOpenChange?.(open); + }, + [onOpenChange], + ); + + const [uploadingFiles, setUploadingFiles] = useState([]); const fileList = useMemo(() => files || [], [files]); @@ -77,11 +91,55 @@ export default function ChatFileSelect({ onRemoveAllFiles(); }, [onRemoveAllFiles]); + const handleFileUpload = useCallback( + (files: File[]) => { + if (!sessionId) return; + + const fileNames = files.map((f) => f.name); + setUploadingFiles((prev) => [...prev, ...fileNames]); + + files.forEach((file) => { + uploadFile( + { + sessionId, + path: file.name, + file, + }, + { + onSettled: () => { + setUploadingFiles((prev) => + prev.filter((name) => name !== file.name), + ); + }, + onSuccess: (file) => onAddFile(file), + onError: (error) => { + console.error(`Failed to upload ${file.name}:`, error); + }, + }, + ); + }); + }, + [sessionId, uploadFile, onAddFile], + ); + + const onDrop = useCallback( + (acceptedFiles: File[]) => { + handleFileUpload(acceptedFiles); + }, + [handleFileUpload], + ); + + const { getRootProps, getInputProps, isDragActive } = useDropzone({ + onDrop, + noClick: true, + noKeyboard: true, + }); + const selectedCount = selectedFiles.length; const totalCount = fileList.length; return ( - + - +
+ + {isDragActive && ( +
+
+ +

+ Drop files here to upload +

)} + + Attach Files + Attach files to your message. + -
- {isLoading ? ( -
- Loading files... +
+ {totalCount > 0 && ( +
+ + {selectedCount} of {totalCount} files selected + +
+ + +
- ) : fileList.length === 0 ? ( -
- -

No files in this session

-

- Upload files by dragging them to the chat input -

+ )} + + {uploadingFiles.length > 0 && ( +
+ + Uploading {uploadingFiles.length} file + {uploadingFiles.length > 1 ? "s" : ""}...
- ) : ( -
- {fileList.map((file) => { - const isSelected = selectedFiles.some( - (f) => f.id === file.id, - ); - return ( -
-
- {isSelected ? ( -
- +
+

+ {file.path} +

+
+ {file.file_type.toUpperCase()} + + {formatFileSize(file.size)} +
- ) : ( -
- )} -
- - ); - })} -
- )} -
+
+
+ {isSelected ? ( +
+ +
+ ) : ( +
+ )} +
+ + ); + })} +
+ )} +
-
- +
+ +
From af189d822f81b14a2ef1d2db1cc156a809d0e944 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 03:01:56 -0400 Subject: [PATCH 24/26] server: set openai store parameter --- server/src/provider/providers/openai.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/provider/providers/openai.rs b/server/src/provider/providers/openai.rs index 366cace..edf2c06 100644 --- a/server/src/provider/providers/openai.rs +++ b/server/src/provider/providers/openai.rs @@ -125,6 +125,7 @@ impl LlmApiProvider for OpenAIProvider { }], max_tokens: options.max_tokens, temperature: options.temperature, + store: (self.base_url == OPENAI_API_BASE_URL).then_some(false), ..Default::default() }; From a0f94676db10ac590a2b27ece0194c1212b0a2ff Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 03:28:44 -0400 Subject: [PATCH 25/26] server: fix anthropic file format --- .../src/provider/providers/anthropic/request.rs | 2 -- server/src/provider/providers/openai/request.rs | 16 +++++++++------- server/src/provider/utils.rs | 7 ++++++- server/src/storage.rs | 5 ++--- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/server/src/provider/providers/anthropic/request.rs b/server/src/provider/providers/anthropic/request.rs index 06cc9f7..9659861 100644 --- a/server/src/provider/providers/anthropic/request.rs +++ b/server/src/provider/providers/anthropic/request.rs @@ -38,7 +38,6 @@ pub fn build_anthropic_messages<'a>( }, }, ChatRsFileType::Image => AnthropicContentBlock::Image { - title: &file.name, source: AnthropicSource::Base64 { data: &file.content, media_type: &file.content_type, @@ -146,7 +145,6 @@ pub enum AnthropicContentBlock<'a> { text: &'a str, }, Image { - title: &'a str, source: AnthropicSource<'a>, }, Document { diff --git a/server/src/provider/providers/openai/request.rs b/server/src/provider/providers/openai/request.rs index 51e9820..87457d3 100644 --- a/server/src/provider/providers/openai/request.rs +++ b/server/src/provider/providers/openai/request.rs @@ -2,7 +2,7 @@ use serde::Serialize; use crate::{ db::models::ChatRsFileType, - provider::{LlmMessage, LlmTool}, + provider::{utils::create_data_uri, LlmMessage, LlmTool}, }; pub fn build_openai_messages<'a>(messages: &'a [LlmMessage]) -> Vec> { @@ -22,11 +22,13 @@ pub fn build_openai_messages<'a>(messages: &'a [LlmMessage]) -> Vec OpenAIContent::ImageUrl { - image_url: OpenAIImageUrl { url: &file.content }, + image_url: OpenAIImageUrl { + url: create_data_uri(&file.content_type, &file.content), + }, }, ChatRsFileType::Pdf => OpenAIContent::File { file: OpenAIFile { - file_data: &file.content, + file_data: create_data_uri(&file.content_type, &file.content), filename: &file.name, }, }, @@ -137,18 +139,18 @@ pub struct OpenAIMessage<'a> { #[serde(tag = "type", rename_all = "snake_case")] pub enum OpenAIContent<'a> { Text { text: &'a str }, - ImageUrl { image_url: OpenAIImageUrl<'a> }, + ImageUrl { image_url: OpenAIImageUrl }, File { file: OpenAIFile<'a> }, } #[derive(Debug, Serialize)] -pub struct OpenAIImageUrl<'a> { - url: &'a str, +pub struct OpenAIImageUrl { + url: String, } #[derive(Debug, Serialize)] pub struct OpenAIFile<'a> { - file_data: &'a str, + file_data: String, filename: &'a str, } diff --git a/server/src/provider/utils.rs b/server/src/provider/utils.rs index 96c791f..713a66c 100644 --- a/server/src/provider/utils.rs +++ b/server/src/provider/utils.rs @@ -1,4 +1,4 @@ -//! Utils for working with LLM responses +//! Utils for working with LLM requests and responses use rocket::futures::TryStreamExt; use serde::de::DeserializeOwned; @@ -10,6 +10,11 @@ use tokio_util::{ use crate::provider::LlmStreamError; +/// Create a data URI +pub fn create_data_uri(content_type: &str, b64_string: &str) -> String { + format!("data:{};base64,{}", content_type, b64_string) +} + /// Get a stream of deserialized events from a provider SSE stream. pub fn get_sse_events( response: reqwest::Response, diff --git a/server/src/storage.rs b/server/src/storage.rs index 68ec40f..5d90cc1 100644 --- a/server/src/storage.rs +++ b/server/src/storage.rs @@ -45,10 +45,9 @@ impl ChatRsFile { String::from_utf8_lossy(&bytes).into() } ChatRsFileType::Image | ChatRsFileType::Pdf => { - let b64_content = storage + storage .read_file_as_base64(&self.user_id, session_id, Path::new(&self.path)) - .await?; - format!("data:{};base64,{}", self.content_type, b64_content) + .await? } }; Ok((file_type, content)) From 0ddcbbde61ec70d8075df2f6024ae283a2288d66 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Tue, 2 Sep 2025 04:01:00 -0400 Subject: [PATCH 26/26] docker: fix data/storage permissions Dockerfile sets up proper file permissions for the non-root user to perform file operations in the /data directory --- Dockerfile | 6 +++++- docker-compose.yml | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 233adbf..8470909 100644 --- a/Dockerfile +++ b/Dockerfile @@ -44,7 +44,7 @@ RUN apt-get update -qq && \ apt-get install -y -qq ca-certificates libpq5 && \ apt-get clean -# Use non-root user +# Create non-root user and data directory ARG UID=10001 RUN adduser \ --disabled-password \ @@ -53,6 +53,9 @@ RUN adduser \ --shell "/sbin/nologin" \ --uid "${UID}" \ appuser +RUN mkdir -p /data +RUN chown -R appuser:appuser /data + USER appuser # Copy app files @@ -61,6 +64,7 @@ COPY --from=backend-build /app/run-server /usr/local/bin/ # Run ENV RS_CHAT_STATIC_PATH=/var/www +ENV RS_CHAT_DATA_DIR=/data ENV RS_CHAT_ADDRESS=0.0.0.0 ENV RS_CHAT_PORT=8080 EXPOSE 8080 diff --git a/docker-compose.yml b/docker-compose.yml index 251c5d0..269973a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,9 +29,11 @@ services: RS_CHAT_SERVER_ADDRESS: http://localhost:8080 RS_CHAT_DATABASE_URL: postgres://postgres:postgres@postgres/postgres RS_CHAT_REDIS_URL: redis://redis:6379 + RS_CHAT_DATA_DIR: /data env_file: server/.env volumes: - ./.docker:/certs + - rschat_data:/data depends_on: - db - redis @@ -39,3 +41,4 @@ services: volumes: postgres_data: redis_data: + rschat_data: