From 6b3482863434b5687bfcda3630664b11e254d246 Mon Sep 17 00:00:00 2001 From: BYND Date: Thu, 21 May 2026 11:13:05 +0300 Subject: [PATCH] Add configurable MCP toolkit support --- .../chat/messages/message-tool.tsx | 77 +++++- src/app/api/chat/route.ts | 224 +++++++++++++----- src/app/api/mcp/[server]/[transport]/route.ts | 51 ++-- src/toolkits/toolkits/client.ts | 2 + src/toolkits/toolkits/mcp/base.ts | 52 ++++ src/toolkits/toolkits/mcp/client.tsx | 20 ++ src/toolkits/toolkits/mcp/form.tsx | 53 +++++ src/toolkits/toolkits/mcp/server.ts | 11 + src/toolkits/toolkits/mcp/tools.ts | 1 + src/toolkits/toolkits/server.ts | 2 + src/toolkits/toolkits/shared.ts | 5 + 11 files changed, 412 insertions(+), 86 deletions(-) create mode 100644 src/toolkits/toolkits/mcp/base.ts create mode 100644 src/toolkits/toolkits/mcp/client.tsx create mode 100644 src/toolkits/toolkits/mcp/form.tsx create mode 100644 src/toolkits/toolkits/mcp/server.ts create mode 100644 src/toolkits/toolkits/mcp/tools.ts diff --git a/src/app/(general)/_components/chat/messages/message-tool.tsx b/src/app/(general)/_components/chat/messages/message-tool.tsx index 9cfc7399..6cadfd28 100644 --- a/src/app/(general)/_components/chat/messages/message-tool.tsx +++ b/src/app/(general)/_components/chat/messages/message-tool.tsx @@ -32,7 +32,11 @@ const MessageToolComponent: React.FC = ({ toolInvocation }) => { const { toolName } = toolInvocation; - const [server, tool] = toolName.split("_"); + const separatorIndex = toolName.indexOf("_"); + const server = + separatorIndex === -1 ? undefined : toolName.slice(0, separatorIndex); + const tool = + separatorIndex === -1 ? undefined : toolName.slice(separatorIndex + 1); if (!server || !tool) { return ( @@ -57,6 +61,16 @@ const MessageToolComponent: React.FC = ({ toolInvocation }) => { const typedTool = tool as ServerToolkitNames[typeof typedServer]; const toolConfig = clientToolkit.tools[typedTool]; + if (!toolConfig) { + return ( + + ); + } + return ( = ({ toolInvocation }) => { ); }; +const GenericToolInvocation: React.FC<{ + clientToolkit: ReturnType; + toolInvocation: ToolInvocation; + toolName: string; +}> = ({ clientToolkit, toolInvocation, toolName }) => { + const argsDefined = toolInvocation.args !== undefined; + + return ( + + + + + {toolInvocation.state === "result" ? ( + + {clientToolkit.name} Toolkit + + ) : ( + + {clientToolkit.name} Toolkit + + )} + {(toolInvocation.state === "call" || + toolInvocation.state === "partial-call") && ( + + )} + +
+

+ {toolName} +

+
+            {JSON.stringify(
+              toolInvocation.state === "result"
+                ? toolInvocation.result
+                : (toolInvocation.args ?? {}),
+              null,
+              2,
+            )}
+          
+
+
+
+ ); +}; + const areEqual = (prevProps: Props, nextProps: Props): boolean => { const { toolInvocation: prev } = prevProps; const { toolInvocation: next } = nextProps; diff --git a/src/app/api/chat/route.ts b/src/app/api/chat/route.ts index 2b858f04..bda603c1 100644 --- a/src/app/api/chat/route.ts +++ b/src/app/api/chat/route.ts @@ -5,6 +5,7 @@ import { appendResponseMessages, convertToCoreMessages, createDataStream, + experimental_createMCPClient as createMCPClient, smoothStream, tool, } from "ai"; @@ -35,9 +36,13 @@ import type { Chat } from "@prisma/client"; import { openai } from "@ai-sdk/openai"; import { getServerToolkit } from "@/toolkits/toolkits/server"; import { languageModels } from "@/ai/language"; +import { Toolkits } from "@/toolkits/toolkits/shared"; +import { mcpParameters, parseMcpHeaders } from "@/toolkits/toolkits/mcp/base"; export const maxDuration = 60; +type MCPClient = Awaited>; + let globalStreamContext: ResumableStreamContext | null = null; function getStreamContext() { @@ -62,6 +67,14 @@ function getStreamContext() { export async function POST(request: Request) { let requestBody: PostRequestBody; + const mcpClients: MCPClient[] = []; + const closeMcpClients = async () => { + await Promise.allSettled( + mcpClients.map(async (client) => { + await client.close(); + }), + ); + }; try { requestBody = postRequestBodySchema.parse(await request.json()); @@ -156,6 +169,10 @@ export async function POST(request: Request) { const toolkitTools = await Promise.all( toolkits.map(async ({ id, parameters }) => { + if (id === Toolkits.MCP) { + return createMcpToolkitTools(parameters, mcpClients); + } + const toolkit = getServerToolkit(id); const tools = await toolkit.tools(parameters); return Object.keys(tools).reduce( @@ -215,7 +232,12 @@ export async function POST(request: Request) { // Collect toolkit system prompts const toolkitSystemPrompts = await Promise.all( - toolkits.map(async ({ id }) => { + toolkits.map(async ({ id, parameters }) => { + if (id === Toolkits.MCP) { + const { url } = mcpParameters.parse(parameters); + return `You have access to a user-configured hosted MCP server at ${url}. Its tools are discovered dynamically and are prefixed with "${Toolkits.MCP}_". Use them when they match the user's request.`; + } + const toolkit = getServerToolkit(id); return toolkit.systemPrompt; }), @@ -256,6 +278,7 @@ export async function POST(request: Request) { experimental_generateMessageId: generateUUID, onError: (error) => { console.error("Stream error occurred:", error); + void closeMcpClients(); // Check if it's a 402 error and log it specifically if (error && typeof error === "object") { @@ -279,71 +302,75 @@ export async function POST(request: Request) { // Don't throw - just let the stream end naturally after sending error data }, onFinish: async ({ response }) => { - // Get the actual model used from OpenRouter's response - const [provider, modelId] = response.modelId.split("/"); - - // Try to find the model in our list first - const model = languageModels.find( - (model) => - model.provider === provider && model.modelId === modelId, - ); - - // Create model info from OpenRouter's response if not in our list - const modelInfo = model ?? { - name: `${provider}/${modelId}`, // Format nicely for display - provider: provider ?? "unknown", - modelId: modelId ?? "unknown", - }; - - // Write the model annotation - dataStream.writeMessageAnnotation({ - type: "model", - model: modelInfo, - }); + try { + // Get the actual model used from OpenRouter's response + const [provider, modelId] = response.modelId.split("/"); + + // Try to find the model in our list first + const model = languageModels.find( + (model) => + model.provider === provider && model.modelId === modelId, + ); + + // Create model info from OpenRouter's response if not in our list + const modelInfo = model ?? { + name: `${provider}/${modelId}`, // Format nicely for display + provider: provider ?? "unknown", + modelId: modelId ?? "unknown", + }; + + // Write the model annotation + dataStream.writeMessageAnnotation({ + type: "model", + model: modelInfo, + }); + + // Send modelId as message annotation + if (session.user?.id) { + try { + const assistantId = getTrailingMessageId({ + messages: response.messages.filter( + (message) => message.role === "assistant", + ), + }); - // Send modelId as message annotation - if (session.user?.id) { - try { - const assistantId = getTrailingMessageId({ - messages: response.messages.filter( - (message) => message.role === "assistant", - ), - }); - - if (!assistantId) { - throw new Error("No assistant message found!"); - } + if (!assistantId) { + throw new Error("No assistant message found!"); + } - const [, assistantMessage] = appendResponseMessages({ - messages: [message], - responseMessages: response.messages, - }); + const [, assistantMessage] = appendResponseMessages({ + messages: [message], + responseMessages: response.messages, + }); - if (!assistantMessage) { - throw new Error("No assistant message found!"); + if (!assistantMessage) { + throw new Error("No assistant message found!"); + } + + await api.messages.createMessage({ + chatId: id, + id: assistantId, + role: "assistant", + parts: assistantMessage.parts ?? [], + attachments: + assistantMessage.experimental_attachments?.map( + (attachment) => ({ + url: attachment.url, + name: attachment.name ?? "", + contentType: attachment.contentType as + | "image/png" + | "image/jpg" + | "image/jpeg", + }), + ) ?? [], + modelId: response.modelId, // Use the actual model from OpenRouter's response + }); + } catch (error) { + console.error(error); } - - await api.messages.createMessage({ - chatId: id, - id: assistantId, - role: "assistant", - parts: assistantMessage.parts ?? [], - attachments: - assistantMessage.experimental_attachments?.map( - (attachment) => ({ - url: attachment.url, - name: attachment.name ?? "", - contentType: attachment.contentType as - | "image/png" - | "image/jpg" - | "image/jpeg", - }), - ) ?? [], - modelId: response.modelId, // Use the actual model from OpenRouter's response - }); - } catch (error) { - console.error(error); } + } finally { + await closeMcpClients(); } }, tools: { @@ -377,6 +404,8 @@ export async function POST(request: Request) { return new Response(stream); } } catch (error) { + await closeMcpClients(); + if (error instanceof ChatSDKError) { return error.toResponse(); } @@ -385,6 +414,79 @@ export async function POST(request: Request) { } } +async function createMcpToolkitTools( + parameters: Record, + mcpClients: MCPClient[], +): Promise> { + const { url, headers } = mcpParameters.parse(parameters); + const mcpClient = await createMCPClient({ + transport: { + type: "sse", + url, + headers: parseMcpHeaders(headers), + }, + }); + mcpClients.push(mcpClient); + + const mcpTools = await mcpClient.tools(); + + return Object.entries(mcpTools).reduce( + (acc, [toolName, mcpTool]) => { + acc[`${Toolkits.MCP}_${toolName}`] = { + ...mcpTool, + execute: async (args, options) => { + const result = await mcpTool.execute(args, options); + + return { + result, + message: getMcpToolResultMessage(result), + }; + }, + }; + return acc; + }, + {} as Record, + ); +} + +function getMcpToolResultMessage(result: unknown) { + if (hasMcpContent(result)) { + const text = result.content + .map((content) => (isMcpTextContent(content) ? content.text : null)) + .filter((content): content is string => content !== null) + .join("\n") + .trim(); + + if (text.length > 0) { + return text; + } + } + + return "The MCP tool returned structured output. Use it to answer the user's request."; +} + +function hasMcpContent(result: unknown): result is { content: unknown[] } { + return ( + typeof result === "object" && + result !== null && + "content" in result && + Array.isArray(result.content) + ); +} + +function isMcpTextContent( + content: unknown, +): content is { type: "text"; text: string } { + return ( + typeof content === "object" && + content !== null && + "type" in content && + content.type === "text" && + "text" in content && + typeof content.text === "string" + ); +} + async function generateTitleFromUserMessage(message: UIMessage) { const { text: title } = await generateText("openai/gpt-4o-mini", { system: `\n diff --git a/src/app/api/mcp/[server]/[transport]/route.ts b/src/app/api/mcp/[server]/[transport]/route.ts index c625adf9..be3d14d3 100644 --- a/src/app/api/mcp/[server]/[transport]/route.ts +++ b/src/app/api/mcp/[server]/[transport]/route.ts @@ -1,5 +1,6 @@ import { serverToolkits } from "@/toolkits/toolkits/server"; import type { Toolkits } from "@/toolkits/toolkits/shared"; +import type { ServerTool } from "@/toolkits/types"; import { createMcpHandler } from "@vercel/mcp-adapter"; // Create a wrapper function that can access Next.js route parameters @@ -18,30 +19,32 @@ async function createHandlerWithParams( model: "openai:gpt-image-1", }); - Object.entries(tools).forEach(([toolName, tool]) => { - const { description, inputSchema, callback, message } = tool; - mcpServer.tool( - toolName, - description, - inputSchema.shape, - async (args) => { - const result = await callback(args); - return { - content: [ - { - type: "text", - text: message - ? typeof message === "function" - ? message(result) - : message - : JSON.stringify(result, null, 2), - }, - ], - structuredContent: result, - }; - }, - ); - }); + Object.entries(tools as Record).forEach( + ([toolName, tool]) => { + const { description, inputSchema, callback, message } = tool; + mcpServer.tool( + toolName, + description, + inputSchema.shape, + async (args: Record) => { + const result = await callback(args); + return { + content: [ + { + type: "text", + text: message + ? typeof message === "function" + ? message(result) + : message + : JSON.stringify(result, null, 2), + }, + ], + structuredContent: result, + }; + }, + ); + }, + ); }, { // Optional server options diff --git a/src/toolkits/toolkits/client.ts b/src/toolkits/toolkits/client.ts index 29087577..06583554 100644 --- a/src/toolkits/toolkits/client.ts +++ b/src/toolkits/toolkits/client.ts @@ -19,6 +19,7 @@ import { spotifyClientToolkit } from "./spotify/client"; import { etsyClientToolkit } from "./etsy/client"; import { videoClientToolkit } from "./video/client"; import { twitterClientToolkit } from "./twitter/client"; +import { mcpClientToolkit } from "./mcp/client"; export type ClientToolkits = { [K in Toolkits]: ClientToolkit< @@ -42,6 +43,7 @@ export const clientToolkits: ClientToolkits = { [Toolkits.Etsy]: etsyClientToolkit, [Toolkits.Video]: videoClientToolkit, [Toolkits.Twitter]: twitterClientToolkit, + [Toolkits.MCP]: mcpClientToolkit, }; export function getClientToolkit( diff --git a/src/toolkits/toolkits/mcp/base.ts b/src/toolkits/toolkits/mcp/base.ts new file mode 100644 index 00000000..1b10dce0 --- /dev/null +++ b/src/toolkits/toolkits/mcp/base.ts @@ -0,0 +1,52 @@ +import { z } from "zod"; + +import type { ToolkitConfig } from "@/toolkits/types"; + +import type { McpTools } from "./tools"; + +const isHeadersJson = (value: string | undefined) => { + if (!value?.trim()) { + return true; + } + + try { + const parsed = JSON.parse(value) as unknown; + + return ( + typeof parsed === "object" && + parsed !== null && + !Array.isArray(parsed) && + Object.values(parsed).every( + (headerValue) => typeof headerValue === "string", + ) + ); + } catch { + return false; + } +}; + +export const mcpParameters = z.object({ + url: z.string().url(), + headers: z + .string() + .optional() + .refine(isHeadersJson, "Headers must be a JSON object with string values"), +}); + +export const parseMcpHeaders = ( + headers: string | undefined, +): Record | undefined => { + if (!headers?.trim()) { + return undefined; + } + + return JSON.parse(headers) as Record; +}; + +export const baseMcpToolkitConfig: ToolkitConfig< + McpTools, + typeof mcpParameters.shape +> = { + tools: {}, + parameters: mcpParameters, +}; diff --git a/src/toolkits/toolkits/mcp/client.tsx b/src/toolkits/toolkits/mcp/client.tsx new file mode 100644 index 00000000..f23d6c14 --- /dev/null +++ b/src/toolkits/toolkits/mcp/client.tsx @@ -0,0 +1,20 @@ +import { Plug } from "lucide-react"; + +import { createClientToolkit } from "@/toolkits/create-toolkit"; +import { ToolkitGroups } from "@/toolkits/types"; + +import { baseMcpToolkitConfig } from "./base"; +import { McpForm } from "./form"; + +export const mcpClientToolkit = createClientToolkit( + baseMcpToolkitConfig, + { + name: "MCP", + description: "Connect to any hosted MCP server", + icon: Plug, + form: McpForm, + type: ToolkitGroups.DataSource, + envVars: [], + }, + {}, +); diff --git a/src/toolkits/toolkits/mcp/form.tsx b/src/toolkits/toolkits/mcp/form.tsx new file mode 100644 index 00000000..523eba20 --- /dev/null +++ b/src/toolkits/toolkits/mcp/form.tsx @@ -0,0 +1,53 @@ +"use client"; + +import { Label } from "@/components/ui/label"; +import { Input } from "@/components/ui/input"; +import { Textarea } from "@/components/ui/textarea"; +import { VStack } from "@/components/ui/stack"; + +import type { z } from "zod"; +import type { mcpParameters } from "./base"; + +interface McpFormProps { + parameters: z.infer; + setParameters: (parameters: z.infer) => void; +} + +export const McpForm: React.FC = ({ + parameters, + setParameters, +}) => { + return ( + + + + + setParameters({ + ...parameters, + url: event.target.value, + }) + } + placeholder="https://example.com/sse" + /> + + + +