From 68796c35ee7ed5f763f4e3039ae6be0b88ba0ee0 Mon Sep 17 00:00:00 2001 From: BYND Date: Thu, 21 May 2026 17:38:47 +0300 Subject: [PATCH] Add file retrieval indexing toolkit --- docker-compose.yml | 4 +- .../migration.sql | 27 +++ prisma/schema.prisma | 40 +++- .../_components/chat/input/index.tsx | 31 ++- .../components/tabs/attachments/table.tsx | 11 ++ src/app/api/files/upload/route.ts | 15 +- src/lib/rag/embeddings.ts | 66 +++++++ src/lib/rag/files.ts | 180 ++++++++++++++++++ src/server/api/routers/files.ts | 4 + src/server/rag/files.ts | 123 ++++++++++++ src/toolkits/toolkits/client.ts | 2 + src/toolkits/toolkits/file-rag/base.ts | 18 ++ src/toolkits/toolkits/file-rag/client.tsx | 24 +++ src/toolkits/toolkits/file-rag/server.ts | 27 +++ .../toolkits/file-rag/tools/client.ts | 1 + src/toolkits/toolkits/file-rag/tools/index.ts | 5 + .../tools/search-uploaded-files/base.ts | 35 ++++ .../tools/search-uploaded-files/client.tsx | 71 +++++++ .../tools/search-uploaded-files/server.ts | 29 +++ .../toolkits/file-rag/tools/server.ts | 1 + src/toolkits/toolkits/server.ts | 2 + src/toolkits/toolkits/shared.ts | 5 + 22 files changed, 705 insertions(+), 16 deletions(-) create mode 100644 prisma/migrations/20260521000100_file_embeddings/migration.sql create mode 100644 src/lib/rag/embeddings.ts create mode 100644 src/lib/rag/files.ts create mode 100644 src/server/rag/files.ts create mode 100644 src/toolkits/toolkits/file-rag/base.ts create mode 100644 src/toolkits/toolkits/file-rag/client.tsx create mode 100644 src/toolkits/toolkits/file-rag/server.ts create mode 100644 src/toolkits/toolkits/file-rag/tools/client.ts create mode 100644 src/toolkits/toolkits/file-rag/tools/index.ts create mode 100644 src/toolkits/toolkits/file-rag/tools/search-uploaded-files/base.ts create mode 100644 src/toolkits/toolkits/file-rag/tools/search-uploaded-files/client.tsx create mode 100644 src/toolkits/toolkits/file-rag/tools/search-uploaded-files/server.ts create mode 100644 src/toolkits/toolkits/file-rag/tools/server.ts diff --git a/docker-compose.yml b/docker-compose.yml index 4265d9c9..010d4884 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,7 +3,7 @@ version: "3.8" services: # PostgreSQL Database postgres: - image: postgres:15 + image: pgvector/pgvector:pg15 container_name: toolkit-postgres environment: POSTGRES_USER: postgres @@ -48,4 +48,4 @@ services: volumes: postgres_data: redis_data: - blob_data: \ No newline at end of file + blob_data: diff --git a/prisma/migrations/20260521000100_file_embeddings/migration.sql b/prisma/migrations/20260521000100_file_embeddings/migration.sql new file mode 100644 index 00000000..8ba192f4 --- /dev/null +++ b/prisma/migrations/20260521000100_file_embeddings/migration.sql @@ -0,0 +1,27 @@ +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TYPE "FileEmbeddingStatus" AS ENUM ('skipped', 'indexing', 'indexed', 'failed'); + +ALTER TABLE "File" +ADD COLUMN "embeddingStatus" "FileEmbeddingStatus" NOT NULL DEFAULT 'skipped'; + +CREATE TABLE "FileEmbeddingChunk" ( + "id" TEXT NOT NULL, + "fileId" TEXT NOT NULL, + "userId" TEXT NOT NULL, + "chunkIndex" INTEGER NOT NULL, + "content" TEXT NOT NULL, + "embedding" vector(384) NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "FileEmbeddingChunk_pkey" PRIMARY KEY ("id") +); + +CREATE UNIQUE INDEX "FileEmbeddingChunk_fileId_chunkIndex_key" ON "FileEmbeddingChunk"("fileId", "chunkIndex"); +CREATE INDEX "FileEmbeddingChunk_fileId_idx" ON "FileEmbeddingChunk"("fileId"); +CREATE INDEX "FileEmbeddingChunk_userId_idx" ON "FileEmbeddingChunk"("userId"); +CREATE INDEX "FileEmbeddingChunk_embedding_idx" ON "FileEmbeddingChunk" USING ivfflat ("embedding" vector_cosine_ops) WITH (lists = 100); + +ALTER TABLE "FileEmbeddingChunk" +ADD CONSTRAINT "FileEmbeddingChunk_fileId_fkey" +FOREIGN KEY ("fileId") REFERENCES "File"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 153dbc73..daf0420f 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -115,13 +115,37 @@ model Video { } model File { - id String @id @default(uuid()) - userId String - name String - contentType String - url String - createdAt DateTime @default(now()) - user User @relation(fields: [userId], references: [id], onDelete: Cascade) + id String @id @default(uuid()) + userId String + name String + contentType String + url String + embeddingStatus FileEmbeddingStatus @default(skipped) + createdAt DateTime @default(now()) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + embeddingChunks FileEmbeddingChunk[] +} + +enum FileEmbeddingStatus { + skipped + indexing + indexed + failed +} + +model FileEmbeddingChunk { + id String @id @default(uuid()) + fileId String + userId String + chunkIndex Int + content String @db.Text + embedding Unsupported("vector(384)") + createdAt DateTime @default(now()) + file File @relation(fields: [fileId], references: [id], onDelete: Cascade) + + @@unique([fileId, chunkIndex]) + @@index([fileId]) + @@index([userId]) } model Image { @@ -190,4 +214,4 @@ model Tool { toolkit Toolkit @relation(fields: [toolkitId], references: [id], onDelete: Cascade) @@id([id, toolkitId]) -} \ No newline at end of file +} diff --git a/src/app/(general)/_components/chat/input/index.tsx b/src/app/(general)/_components/chat/input/index.tsx index 29428372..9baeefa8 100644 --- a/src/app/(general)/_components/chat/input/index.tsx +++ b/src/app/(general)/_components/chat/input/index.tsx @@ -96,6 +96,9 @@ const PureMultimodalInput: React.FC = ({ "input", "", ); + const [fileIndexingEnabled, setFileIndexingEnabled] = useLocalStorage< + boolean | null + >("file-indexing-enabled", null); useEffect(() => { if (textareaRef.current) { @@ -121,6 +124,18 @@ const PureMultimodalInput: React.FC = ({ const fileInputRef = useRef(null); const [uploadQueue, setUploadQueue] = useState>([]); + const requestFileIndexingPreference = useCallback(() => { + if (fileIndexingEnabled !== null) { + return fileIndexingEnabled; + } + + const shouldIndex = window.confirm( + "Index uploaded files for retrieval? This stores searchable chunks in your database and enables the File Search toolkit.", + ); + setFileIndexingEnabled(shouldIndex); + return shouldIndex; + }, [fileIndexingEnabled, setFileIndexingEnabled]); + const supportsImages = selectedChatModel?.capabilities?.includes( LanguageModelCapability.Vision, ); @@ -199,9 +214,10 @@ const PureMultimodalInput: React.FC = ({ ]); const uploadFile = useCallback( - async (file: File): Promise => { + async (file: File, indexFile: boolean): Promise => { const formData = new FormData(); formData.append("file", file); + formData.append("indexFile", String(indexFile)); try { const response = await fetch("/api/files/upload", { @@ -235,11 +251,15 @@ const PureMultimodalInput: React.FC = ({ const handleFileChange = useCallback( async (event: React.ChangeEvent) => { const files = Array.from(event.target.files ?? []); + const indexFiles = + files.length > 0 ? requestFileIndexingPreference() : false; setUploadQueue(files.map((file) => file.name)); try { - const uploadPromises = files.map((file) => uploadFile(file)); + const uploadPromises = files.map((file) => + uploadFile(file, indexFiles), + ); const uploadedAttachments = await Promise.all(uploadPromises); const successfullyUploadedAttachments = uploadedAttachments.filter( (attachment): attachment is Attachment => attachment !== undefined, @@ -255,7 +275,7 @@ const PureMultimodalInput: React.FC = ({ setUploadQueue([]); } }, - [setAttachments, uploadFile], + [requestFileIndexingPreference, setAttachments, uploadFile], ); useEffect(() => { @@ -302,10 +322,11 @@ const PureMultimodalInput: React.FC = ({ return; } + const indexFiles = requestFileIndexingPreference(); setUploadQueue(imageFiles.map((file) => file.name)); // Handle async upload in non-blocking way - Promise.all(imageFiles.map((file) => uploadFile(file))) + Promise.all(imageFiles.map((file) => uploadFile(file, indexFiles))) .then((uploadedAttachments) => { const successfullyUploadedAttachments = uploadedAttachments.filter( (attachment): attachment is Attachment => @@ -332,7 +353,7 @@ const PureMultimodalInput: React.FC = ({ }); } }, - [supportsImages, setAttachments, uploadFile], + [requestFileIndexingPreference, supportsImages, setAttachments, uploadFile], ); useEffect(() => { diff --git a/src/app/(general)/account/components/tabs/attachments/table.tsx b/src/app/(general)/account/components/tabs/attachments/table.tsx index 85ba6577..aff7fe70 100644 --- a/src/app/(general)/account/components/tabs/attachments/table.tsx +++ b/src/app/(general)/account/components/tabs/attachments/table.tsx @@ -36,6 +36,7 @@ export function DataTableDemo() { const [visibleColumns, setVisibleColumns] = React.useState({ name: true, contentType: true, + embeddingStatus: true, actions: true, }); @@ -184,6 +185,9 @@ export function DataTableDemo() { {visibleColumns.contentType && ( Content Type )} + {visibleColumns.embeddingStatus && ( + Indexing + )} {visibleColumns.actions && ( )} @@ -237,6 +241,13 @@ export function DataTableDemo() {
{attachment.contentType}
)} + {visibleColumns.embeddingStatus && ( + +
+ {attachment.embeddingStatus} +
+
+ )} {visibleColumns.actions && ( diff --git a/src/app/api/files/upload/route.ts b/src/app/api/files/upload/route.ts index adedde94..87c8e8d7 100644 --- a/src/app/api/files/upload/route.ts +++ b/src/app/api/files/upload/route.ts @@ -7,6 +7,7 @@ import { put } from "@vercel/blob"; import { auth } from "@/server/auth"; import { api } from "@/trpc/server"; import { FILE_MAX_SIZE, IS_DEVELOPMENT } from "@/lib/constants"; +import { indexUploadedFile } from "@/server/rag/files"; // Use Blob instead of File since File is not available in Node.js environment const FileSchema = z.object({ @@ -20,7 +21,7 @@ const FileSchema = z.object({ (file) => ["image/jpeg", "image/png", "application/pdf"].includes(file.type), { - message: "File type should be JPEG or PNG", + message: "File type should be JPEG, PNG, or PDF", }, ), }); @@ -56,6 +57,7 @@ export async function POST(request: Request) { // Get filename from formData since Blob doesn't have name property const filename = (formData.get("file") as File).name; + const shouldIndexFile = formData.get("indexFile") === "true"; const fileBuffer = await file.arrayBuffer(); try { @@ -71,8 +73,19 @@ export async function POST(request: Request) { name: filename, url: data.url, contentType: validatedFile.data.file.type, + embeddingStatus: shouldIndexFile ? "indexing" : "skipped", }); + if (shouldIndexFile) { + file.embeddingStatus = await indexUploadedFile({ + fileId: file.id, + userId: session.user.id, + name: filename, + contentType: validatedFile.data.file.type, + arrayBuffer: fileBuffer, + }); + } + if (IS_DEVELOPMENT) { // In development, use base64 data URLs so openrouter can access the file // We can't use the local blob storage because it's not accessible to openrouter diff --git a/src/lib/rag/embeddings.ts b/src/lib/rag/embeddings.ts new file mode 100644 index 00000000..03005f89 --- /dev/null +++ b/src/lib/rag/embeddings.ts @@ -0,0 +1,66 @@ +export const FILE_EMBEDDING_DIMENSIONS = 384; + +const MAX_TOKENS = 2_000; +const MAX_CHARACTER_FEATURES = 6_000; + +export function embedText(text: string) { + const vector = new Array(FILE_EMBEDDING_DIMENSIONS).fill(0); + const normalized = text.normalize("NFKC").toLowerCase(); + const tokens = normalized.match(/[\p{L}\p{N}]+/gu)?.slice(0, MAX_TOKENS); + + if (tokens?.length) { + tokens.forEach((token) => addFeature(vector, `w:${token}`, 1)); + + for (let index = 0; index < tokens.length - 1; index++) { + addFeature(vector, `b:${tokens[index]!} ${tokens[index + 1]!}`, 0.7); + } + } + + const compactText = normalized.replace(/\s+/g, " "); + const characterFeatureCount = Math.min( + compactText.length - 2, + MAX_CHARACTER_FEATURES, + ); + + for (let index = 0; index < characterFeatureCount; index++) { + addFeature(vector, `c:${compactText.slice(index, index + 3)}`, 0.25); + } + + normalizeVector(vector); + return vector; +} + +export function toVectorLiteral(vector: number[]) { + return `[${vector.map((value) => Number(value.toFixed(6))).join(",")}]`; +} + +function addFeature(vector: number[], feature: string, weight: number) { + const hash = hashString(feature); + const index = hash % FILE_EMBEDDING_DIMENSIONS; + const sign = hash & 1 ? 1 : -1; + vector[index] = (vector[index] ?? 0) + sign * weight; +} + +function normalizeVector(vector: number[]) { + const norm = Math.sqrt(vector.reduce((sum, value) => sum + value ** 2, 0)); + + if (norm === 0) { + vector[0] = 1; + return; + } + + for (let index = 0; index < vector.length; index++) { + vector[index] = vector[index]! / norm; + } +} + +function hashString(value: string) { + let hash = 2166136261; + + for (let index = 0; index < value.length; index++) { + hash ^= value.charCodeAt(index); + hash = Math.imul(hash, 16777619); + } + + return hash >>> 0; +} diff --git a/src/lib/rag/files.ts b/src/lib/rag/files.ts new file mode 100644 index 00000000..e09d428b --- /dev/null +++ b/src/lib/rag/files.ts @@ -0,0 +1,180 @@ +import { inflateSync } from "node:zlib"; + +const MAX_CHUNKS_PER_FILE = 40; +const MAX_CHUNK_LENGTH = 1_200; +const CHUNK_OVERLAP = 160; + +type IndexableFile = { + name: string; + contentType: string; + arrayBuffer: ArrayBuffer; +}; + +export function getIndexableText(file: IndexableFile) { + const metadataText = `File name: ${file.name}. MIME type: ${file.contentType}.`; + + if (file.contentType === "application/pdf") { + const pdfText = extractPdfText(Buffer.from(file.arrayBuffer)); + return cleanText(`${metadataText}\n\n${pdfText}`); + } + + if (file.contentType.startsWith("image/")) { + return `${metadataText} Uploaded image. Use the file name and surrounding chat context when searching image uploads.`; + } + + return metadataText; +} + +export function chunkText(text: string) { + const cleanedText = cleanText(text); + + if (!cleanedText) { + return []; + } + + const chunks: string[] = []; + let start = 0; + + while (start < cleanedText.length && chunks.length < MAX_CHUNKS_PER_FILE) { + let end = Math.min(start + MAX_CHUNK_LENGTH, cleanedText.length); + + if (end < cleanedText.length) { + const lastSentenceBreak = Math.max( + cleanedText.lastIndexOf(". ", end), + cleanedText.lastIndexOf("? ", end), + cleanedText.lastIndexOf("! ", end), + ); + const lastWhitespace = cleanedText.lastIndexOf(" ", end); + const preferredBreak = + lastSentenceBreak > start + MAX_CHUNK_LENGTH * 0.55 + ? lastSentenceBreak + 1 + : lastWhitespace; + + if (preferredBreak > start + MAX_CHUNK_LENGTH * 0.55) { + end = preferredBreak; + } + } + + chunks.push(cleanedText.slice(start, end).trim()); + + if (end === cleanedText.length) { + break; + } + + start = Math.max(end - CHUNK_OVERLAP, start + 1); + } + + return chunks.filter(Boolean); +} + +function extractPdfText(buffer: Buffer) { + const latin1 = buffer.toString("latin1"); + const parts: string[] = []; + + parts.push(...extractPdfTextOperators(latin1)); + + const streamPattern = /stream\r?\n([\s\S]*?)\r?\nendstream/g; + let match: RegExpExecArray | null; + + while ((match = streamPattern.exec(latin1))) { + const rawStream = trimPdfStream(Buffer.from(match[1] ?? "", "latin1")); + const dictionaryStart = Math.max(0, match.index - 700); + const dictionary = latin1.slice(dictionaryStart, match.index); + const streamBuffer = dictionary.includes("/FlateDecode") + ? tryInflate(rawStream) + : rawStream; + + if (streamBuffer) { + parts.push(...extractPdfTextOperators(streamBuffer.toString("utf8"))); + parts.push(...extractPdfTextOperators(streamBuffer.toString("latin1"))); + } + } + + return parts.join(" "); +} + +function extractPdfTextOperators(source: string) { + const parts: string[] = []; + const literalPattern = /\((?:\\.|[^\\()])*\)/g; + let match: RegExpExecArray | null; + + while ((match = literalPattern.exec(source))) { + const literal = match[0]; + const decoded = decodePdfLiteral(literal.slice(1, -1)); + + if (isReadableText(decoded)) { + parts.push(decoded); + } + } + + return parts; +} + +function decodePdfLiteral(value: string) { + return value + .replace(/\\([nrtbf()\\])/g, (_, escape: string) => { + switch (escape) { + case "n": + return "\n"; + case "r": + return "\r"; + case "t": + return "\t"; + case "b": + return "\b"; + case "f": + return "\f"; + default: + return escape; + } + }) + .replace(/\\([0-7]{1,3})/g, (_, octal: string) => + String.fromCharCode(Number.parseInt(octal, 8)), + ) + .replace(/\\\r?\n/g, ""); +} + +function isReadableText(value: string) { + const cleaned = cleanText(value); + + if (cleaned.length < 2) { + return false; + } + + const printableCharacters = cleaned.replace(/[^\x20-\x7E]/g, "").length; + return printableCharacters / cleaned.length > 0.6; +} + +function cleanText(value: string) { + return value + .replace(/\u0000/g, "") + .replace(/\s+/g, " ") + .trim(); +} + +function trimPdfStream(buffer: Buffer) { + let start = 0; + let end = buffer.length; + + if (buffer[start] === 0x0d && buffer[start + 1] === 0x0a) { + start += 2; + } else if (buffer[start] === 0x0a) { + start += 1; + } + + if (buffer[end - 2] === 0x0d && buffer[end - 1] === 0x0a) { + end -= 2; + } else if (buffer[end - 1] === 0x0a) { + end -= 1; + } + + return buffer.subarray(start, end); +} + +function tryInflate(buffer: Buffer) { + try { + return inflateSync(buffer); + } catch { + return null; + } +} diff --git a/src/server/api/routers/files.ts b/src/server/api/routers/files.ts index f9af0bd2..6c97369e 100644 --- a/src/server/api/routers/files.ts +++ b/src/server/api/routers/files.ts @@ -67,6 +67,9 @@ export const filesRouter = createTRPCRouter({ name: z.string().min(1).max(FILE_NAME_MAX_LENGTH), contentType: z.string(), url: z.string().url(), + embeddingStatus: z + .enum(["skipped", "indexing", "indexed", "failed"]) + .optional(), }), ) .mutation(async ({ ctx, input }) => { @@ -77,6 +80,7 @@ export const filesRouter = createTRPCRouter({ name: input.name, contentType: input.contentType, url: input.url, + embeddingStatus: input.embeddingStatus ?? "skipped", userId, }, }); diff --git a/src/server/rag/files.ts b/src/server/rag/files.ts new file mode 100644 index 00000000..72173099 --- /dev/null +++ b/src/server/rag/files.ts @@ -0,0 +1,123 @@ +import "server-only"; + +import { randomUUID } from "node:crypto"; + +import { Prisma } from "@prisma/client"; + +import { db } from "@/server/db"; +import { chunkText, getIndexableText } from "@/lib/rag/files"; +import { embedText, toVectorLiteral } from "@/lib/rag/embeddings"; + +type FileEmbeddingStatus = "skipped" | "indexing" | "indexed" | "failed"; + +type IndexUploadedFileInput = { + fileId: string; + userId: string; + name: string; + contentType: string; + arrayBuffer: ArrayBuffer; +}; + +type SearchIndexedFileChunksInput = { + userId: string; + query: string; + limit?: number; + fileIds?: string[]; +}; + +type SearchRow = { + chunkId: string; + fileId: string; + fileName: string; + contentType: string; + chunkIndex: number; + content: string; + score: number | string; +}; + +export async function indexUploadedFile({ + fileId, + userId, + name, + contentType, + arrayBuffer, +}: IndexUploadedFileInput): Promise { + try { + const text = getIndexableText({ name, contentType, arrayBuffer }); + const chunks = chunkText(text); + + await db.$executeRaw( + Prisma.sql`DELETE FROM "FileEmbeddingChunk" WHERE "fileId" = ${fileId}`, + ); + + if (chunks.length === 0) { + await updateFileEmbeddingStatus(fileId, "failed"); + return "failed"; + } + + for (const [chunkIndex, content] of chunks.entries()) { + const embedding = toVectorLiteral(embedText(content)); + + await db.$executeRaw( + Prisma.sql` + INSERT INTO "FileEmbeddingChunk" ("id", "fileId", "userId", "chunkIndex", "content", "embedding") + VALUES (${randomUUID()}, ${fileId}, ${userId}, ${chunkIndex}, ${content}, ${embedding}::vector) + `, + ); + } + + await updateFileEmbeddingStatus(fileId, "indexed"); + return "indexed"; + } catch (error) { + console.error("Failed to index uploaded file:", error); + await updateFileEmbeddingStatus(fileId, "failed"); + return "failed"; + } +} + +export async function searchIndexedFileChunks({ + userId, + query, + limit = 5, + fileIds, +}: SearchIndexedFileChunksInput) { + const embedding = toVectorLiteral(embedText(query)); + const sanitizedLimit = Math.min(Math.max(limit, 1), 10); + const fileFilter = fileIds?.length + ? Prisma.sql`AND c."fileId" IN (${Prisma.join(fileIds)})` + : Prisma.empty; + + const rows = await db.$queryRaw( + Prisma.sql` + SELECT + c."id" AS "chunkId", + c."fileId" AS "fileId", + f."name" AS "fileName", + f."contentType" AS "contentType", + c."chunkIndex" AS "chunkIndex", + c."content" AS "content", + 1 - (c."embedding" <=> ${embedding}::vector) AS "score" + FROM "FileEmbeddingChunk" c + INNER JOIN "File" f ON f."id" = c."fileId" + WHERE c."userId" = ${userId} + ${fileFilter} + ORDER BY c."embedding" <=> ${embedding}::vector + LIMIT ${sanitizedLimit} + `, + ); + + return rows.map((row) => ({ + ...row, + score: Number(row.score), + })); +} + +async function updateFileEmbeddingStatus( + fileId: string, + embeddingStatus: FileEmbeddingStatus, +) { + await db.file.update({ + where: { id: fileId }, + data: { embeddingStatus }, + }); +} diff --git a/src/toolkits/toolkits/client.ts b/src/toolkits/toolkits/client.ts index 29087577..a47d6265 100644 --- a/src/toolkits/toolkits/client.ts +++ b/src/toolkits/toolkits/client.ts @@ -13,6 +13,7 @@ import { googleDriveClientToolkit } from "./google-drive/client"; import { mem0ClientToolkit } from "./mem0/client"; import { notionClientToolkit } from "./notion/client"; import { e2bClientToolkit } from "./e2b/client"; +import { fileRagClientToolkit } from "./file-rag/client"; import { discordClientToolkit } from "./discord/client"; import { stravaClientToolkit } from "./strava/client"; import { spotifyClientToolkit } from "./spotify/client"; @@ -36,6 +37,7 @@ export const clientToolkits: ClientToolkits = { [Toolkits.GoogleCalendar]: googleCalendarClientToolkit, [Toolkits.Notion]: notionClientToolkit, [Toolkits.GoogleDrive]: googleDriveClientToolkit, + [Toolkits.FileRag]: fileRagClientToolkit, [Toolkits.Discord]: discordClientToolkit, [Toolkits.Strava]: stravaClientToolkit, [Toolkits.Spotify]: spotifyClientToolkit, diff --git a/src/toolkits/toolkits/file-rag/base.ts b/src/toolkits/toolkits/file-rag/base.ts new file mode 100644 index 00000000..19536bd7 --- /dev/null +++ b/src/toolkits/toolkits/file-rag/base.ts @@ -0,0 +1,18 @@ +import { z } from "zod"; + +import type { ToolkitConfig } from "@/toolkits/types"; + +import { FileRagTools } from "./tools"; +import { searchUploadedFilesTool } from "./tools/search-uploaded-files/base"; + +export const fileRagParameters = z.object({}); + +export const baseFileRagToolkitConfig: ToolkitConfig< + FileRagTools, + typeof fileRagParameters.shape +> = { + tools: { + [FileRagTools.SearchUploadedFiles]: searchUploadedFilesTool, + }, + parameters: fileRagParameters, +}; diff --git a/src/toolkits/toolkits/file-rag/client.tsx b/src/toolkits/toolkits/file-rag/client.tsx new file mode 100644 index 00000000..297d9890 --- /dev/null +++ b/src/toolkits/toolkits/file-rag/client.tsx @@ -0,0 +1,24 @@ +import { FileSearch } from "lucide-react"; + +import { createClientToolkit } from "@/toolkits/create-toolkit"; +import { ToolkitGroups } from "@/toolkits/types"; + +import { baseFileRagToolkitConfig } from "./base"; +import { FileRagTools } from "./tools"; +import { fileRagSearchUploadedFilesToolConfigClient } from "./tools/client"; + +export const fileRagClientToolkit = createClientToolkit( + baseFileRagToolkitConfig, + { + name: "File Search", + description: "Search indexed PDF and image uploads.", + icon: ({ className }) => , + form: null, + type: ToolkitGroups.KnowledgeBase, + envVars: [], + }, + { + [FileRagTools.SearchUploadedFiles]: + fileRagSearchUploadedFilesToolConfigClient, + }, +); diff --git a/src/toolkits/toolkits/file-rag/server.ts b/src/toolkits/toolkits/file-rag/server.ts new file mode 100644 index 00000000..8b75df0e --- /dev/null +++ b/src/toolkits/toolkits/file-rag/server.ts @@ -0,0 +1,27 @@ +import { auth } from "@/server/auth"; +import { createServerToolkit } from "@/toolkits/create-toolkit"; + +import { baseFileRagToolkitConfig } from "./base"; +import { FileRagTools } from "./tools"; +import { fileRagSearchUploadedFilesToolConfigServer } from "./tools/server"; + +export const fileRagToolkitServer = createServerToolkit( + baseFileRagToolkitConfig, + `You have access to the File Search toolkit for retrieval over the user's indexed uploads. This toolkit provides: + +- **Search Uploaded Files**: Find relevant chunks from PDFs and image uploads that the user opted in to index + +Use this toolkit when the user asks about uploaded files, documents, PDFs, screenshots, or images. Search first, then answer from the returned chunks. Include file names when citing or summarizing retrieved information. If no results are found, explain that the relevant upload may not have been indexed yet.`, + async () => { + const session = await auth(); + + if (!session?.user?.id) { + throw new Error("User not found"); + } + + return { + [FileRagTools.SearchUploadedFiles]: + fileRagSearchUploadedFilesToolConfigServer(session.user.id), + }; + }, +); diff --git a/src/toolkits/toolkits/file-rag/tools/client.ts b/src/toolkits/toolkits/file-rag/tools/client.ts new file mode 100644 index 00000000..f6e84df4 --- /dev/null +++ b/src/toolkits/toolkits/file-rag/tools/client.ts @@ -0,0 +1 @@ +export { fileRagSearchUploadedFilesToolConfigClient } from "./search-uploaded-files/client"; diff --git a/src/toolkits/toolkits/file-rag/tools/index.ts b/src/toolkits/toolkits/file-rag/tools/index.ts new file mode 100644 index 00000000..cfb9505e --- /dev/null +++ b/src/toolkits/toolkits/file-rag/tools/index.ts @@ -0,0 +1,5 @@ +export enum FileRagTools { + SearchUploadedFiles = "search-uploaded-files", +} + +export { searchUploadedFilesTool } from "./search-uploaded-files/base"; diff --git a/src/toolkits/toolkits/file-rag/tools/search-uploaded-files/base.ts b/src/toolkits/toolkits/file-rag/tools/search-uploaded-files/base.ts new file mode 100644 index 00000000..469f8ab6 --- /dev/null +++ b/src/toolkits/toolkits/file-rag/tools/search-uploaded-files/base.ts @@ -0,0 +1,35 @@ +import { z } from "zod"; + +import { createBaseTool } from "@/toolkits/create-tool"; + +export const searchUploadedFilesTool = createBaseTool({ + description: + "Search indexed uploaded PDF and image files using vector retrieval.", + inputSchema: z.object({ + query: z.string().describe("Search query for the uploaded files"), + limit: z + .number() + .min(1) + .max(10) + .optional() + .describe("Maximum number of chunks to return. Defaults to 5."), + fileIds: z + .array(z.string()) + .optional() + .describe("Optional file IDs to restrict retrieval to specific files."), + }), + outputSchema: z.object({ + query: z.string(), + results: z.array( + z.object({ + chunkId: z.string(), + fileId: z.string(), + fileName: z.string(), + contentType: z.string(), + chunkIndex: z.number(), + content: z.string(), + score: z.number(), + }), + ), + }), +}); diff --git a/src/toolkits/toolkits/file-rag/tools/search-uploaded-files/client.tsx b/src/toolkits/toolkits/file-rag/tools/search-uploaded-files/client.tsx new file mode 100644 index 00000000..161f2c98 --- /dev/null +++ b/src/toolkits/toolkits/file-rag/tools/search-uploaded-files/client.tsx @@ -0,0 +1,71 @@ +import React from "react"; +import { FileSearch, Search, Star } from "lucide-react"; + +import { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "@/components/ui/accordion"; + +import { type searchUploadedFilesTool } from "./base"; + +import type { ClientToolConfig } from "@/toolkits/types"; + +export const fileRagSearchUploadedFilesToolConfigClient: ClientToolConfig< + typeof searchUploadedFilesTool.inputSchema.shape, + typeof searchUploadedFilesTool.outputSchema.shape +> = { + CallComponent: ({ args }) => ( +
+ + + Searching uploaded files for {args.query ? `"${args.query}"` : "..."} + +
+ ), + ResultComponent: ({ result }) => { + if (result.results.length === 0) { + return ( +
+ No indexed file chunks found +
+ ); + } + + return ( + + + +

+ + Found {result.results.length} file chunk + {result.results.length === 1 ? "" : "s"} +

+
+ + {result.results.map((item) => ( +
+
+ + {item.fileName} + + + + {(item.score * 100).toFixed(0)}% + +
+

+ {item.content} +

+
+ ))} +
+
+
+ ); + }, +}; diff --git a/src/toolkits/toolkits/file-rag/tools/search-uploaded-files/server.ts b/src/toolkits/toolkits/file-rag/tools/search-uploaded-files/server.ts new file mode 100644 index 00000000..6ff51bf1 --- /dev/null +++ b/src/toolkits/toolkits/file-rag/tools/search-uploaded-files/server.ts @@ -0,0 +1,29 @@ +import { type searchUploadedFilesTool } from "./base"; +import { searchIndexedFileChunks } from "@/server/rag/files"; + +import type { ServerToolConfig } from "@/toolkits/types"; + +export const fileRagSearchUploadedFilesToolConfigServer = ( + userId: string, +): ServerToolConfig< + typeof searchUploadedFilesTool.inputSchema.shape, + typeof searchUploadedFilesTool.outputSchema.shape +> => ({ + callback: async ({ query, limit, fileIds }) => { + const results = await searchIndexedFileChunks({ + userId, + query, + limit, + fileIds, + }); + + return { + query, + results, + }; + }, + message: (result) => + result.results.length > 0 + ? "These are the most relevant indexed file chunks. Use them to answer the user and cite the file names when useful." + : "No indexed file chunks matched the query. Ask the user to upload and index relevant files if needed.", +}); diff --git a/src/toolkits/toolkits/file-rag/tools/server.ts b/src/toolkits/toolkits/file-rag/tools/server.ts new file mode 100644 index 00000000..4ae2e689 --- /dev/null +++ b/src/toolkits/toolkits/file-rag/tools/server.ts @@ -0,0 +1 @@ +export { fileRagSearchUploadedFilesToolConfigServer } from "./search-uploaded-files/server"; diff --git a/src/toolkits/toolkits/server.ts b/src/toolkits/toolkits/server.ts index ed03b433..7ddc024a 100644 --- a/src/toolkits/toolkits/server.ts +++ b/src/toolkits/toolkits/server.ts @@ -7,6 +7,7 @@ import { imageToolkitServer } from "./image/server"; import { mem0ToolkitServer } from "./mem0/server"; import { notionToolkitServer } from "./notion/server"; import { e2bToolkitServer } from "./e2b/server"; +import { fileRagToolkitServer } from "./file-rag/server"; import { discordToolkitServer } from "./discord/server"; import { stravaToolkitServer } from "./strava/server"; import { spotifyToolkitServer } from "./spotify/server"; @@ -35,6 +36,7 @@ export const serverToolkits: ServerToolkits = { [Toolkits.Memory]: mem0ToolkitServer, [Toolkits.Notion]: notionToolkitServer, [Toolkits.E2B]: e2bToolkitServer, + [Toolkits.FileRag]: fileRagToolkitServer, [Toolkits.Discord]: discordToolkitServer, [Toolkits.Strava]: stravaToolkitServer, [Toolkits.Spotify]: spotifyToolkitServer, diff --git a/src/toolkits/toolkits/shared.ts b/src/toolkits/toolkits/shared.ts index 12040b75..d9e818d2 100644 --- a/src/toolkits/toolkits/shared.ts +++ b/src/toolkits/toolkits/shared.ts @@ -14,6 +14,8 @@ import type { notionParameters } from "./notion/base"; import type { NotionTools } from "./notion/tools"; import type { e2bParameters } from "./e2b/base"; import type { E2BTools } from "./e2b/tools/tools"; +import type { fileRagParameters } from "./file-rag/base"; +import type { FileRagTools } from "./file-rag/tools"; import type { discordParameters } from "./discord/base"; import type { DiscordTools } from "./discord/tools"; import type { stravaParameters } from "./strava/base"; @@ -36,6 +38,7 @@ export enum Toolkits { Memory = "memory", Notion = "notion", E2B = "e2b", + FileRag = "file-rag", Discord = "discord", Strava = "strava", Spotify = "spotify", @@ -53,6 +56,7 @@ export type ServerToolkitNames = { [Toolkits.Memory]: Mem0Tools; [Toolkits.Notion]: NotionTools; [Toolkits.E2B]: E2BTools; + [Toolkits.FileRag]: FileRagTools; [Toolkits.Discord]: DiscordTools; [Toolkits.Strava]: StravaTools; [Toolkits.Spotify]: SpotifyTools; @@ -70,6 +74,7 @@ export type ServerToolkitParameters = { [Toolkits.Memory]: typeof mem0Parameters.shape; [Toolkits.Notion]: typeof notionParameters.shape; [Toolkits.E2B]: typeof e2bParameters.shape; + [Toolkits.FileRag]: typeof fileRagParameters.shape; [Toolkits.Discord]: typeof discordParameters.shape; [Toolkits.Strava]: typeof stravaParameters.shape; [Toolkits.Spotify]: typeof spotifyParameters.shape;