Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion src/app/(general)/_components/chat/messages/message-tool.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ const MessageToolComponent: React.FC<Props> = ({ toolInvocation }) => {

const { toolName } = toolInvocation;

const [server, tool] = toolName.split("_");
const separatorIndex = toolName.indexOf("_");
const server =
separatorIndex === -1 ? undefined : toolName.slice(0, separatorIndex);
const tool =
separatorIndex === -1 ? undefined : toolName.slice(separatorIndex + 1);

if (!server || !tool) {
return (
Expand All @@ -57,6 +61,16 @@ const MessageToolComponent: React.FC<Props> = ({ toolInvocation }) => {
const typedTool = tool as ServerToolkitNames[typeof typedServer];
const toolConfig = clientToolkit.tools[typedTool];

if (!toolConfig) {
return (
<GenericToolInvocation
clientToolkit={clientToolkit}
toolInvocation={toolInvocation}
toolName={tool}
/>
);
}

return (
<motion.div
initial={{
Expand Down Expand Up @@ -217,6 +231,67 @@ const MessageToolComponent: React.FC<Props> = ({ toolInvocation }) => {
);
};

const GenericToolInvocation: React.FC<{
clientToolkit: ReturnType<typeof getClientToolkit>;
toolInvocation: ToolInvocation;
toolName: string;
}> = ({ clientToolkit, toolInvocation, toolName }) => {
const argsDefined = toolInvocation.args !== undefined;

return (
<motion.div
initial={{
opacity: argsDefined ? 1 : 0,
y: argsDefined ? 0 : 20,
scale: argsDefined ? 1 : 0.95,
}}
animate={
!argsDefined
? {
opacity: 1,
y: 0,
scale: 1,
}
: undefined
}
transition={{ duration: 0.4, ease: "easeInOut" }}
>
<Card className="gap-0 overflow-hidden p-0">
<HStack className="border-b p-2">
<clientToolkit.icon className="size-4" />
{toolInvocation.state === "result" ? (
<span className="text-lg font-medium">
{clientToolkit.name} Toolkit
</span>
) : (
<AnimatedShinyText className="text-lg font-medium">
{clientToolkit.name} Toolkit
</AnimatedShinyText>
)}
{(toolInvocation.state === "call" ||
toolInvocation.state === "partial-call") && (
<Loader2 className="size-4 animate-spin opacity-60" />
)}
</HStack>
<div className="space-y-2 p-2">
<p className="text-muted-foreground text-sm font-medium">
{toolName}
</p>
<pre className="bg-muted max-h-80 w-full max-w-full overflow-auto rounded-md p-2 text-xs whitespace-pre-wrap">
{JSON.stringify(
toolInvocation.state === "result"
? toolInvocation.result
: (toolInvocation.args ?? {}),
null,
2,
)}
</pre>
</div>
</Card>
</motion.div>
);
};

const areEqual = (prevProps: Props, nextProps: Props): boolean => {
const { toolInvocation: prev } = prevProps;
const { toolInvocation: next } = nextProps;
Expand Down
224 changes: 163 additions & 61 deletions src/app/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
appendResponseMessages,
convertToCoreMessages,
createDataStream,
experimental_createMCPClient as createMCPClient,
smoothStream,
tool,
} from "ai";
Expand Down Expand Up @@ -35,9 +36,13 @@ import type { Chat } from "@prisma/client";
import { openai } from "@ai-sdk/openai";
import { getServerToolkit } from "@/toolkits/toolkits/server";
import { languageModels } from "@/ai/language";
import { Toolkits } from "@/toolkits/toolkits/shared";
import { mcpParameters, parseMcpHeaders } from "@/toolkits/toolkits/mcp/base";

export const maxDuration = 60;

type MCPClient = Awaited<ReturnType<typeof createMCPClient>>;

let globalStreamContext: ResumableStreamContext | null = null;

function getStreamContext() {
Expand All @@ -62,6 +67,14 @@ function getStreamContext() {

export async function POST(request: Request) {
let requestBody: PostRequestBody;
const mcpClients: MCPClient[] = [];
const closeMcpClients = async () => {
await Promise.allSettled(
mcpClients.map(async (client) => {
await client.close();
}),
);
};

try {
requestBody = postRequestBodySchema.parse(await request.json());
Expand Down Expand Up @@ -156,6 +169,10 @@ export async function POST(request: Request) {

const toolkitTools = await Promise.all(
toolkits.map(async ({ id, parameters }) => {
if (id === Toolkits.MCP) {
return createMcpToolkitTools(parameters, mcpClients);
}

const toolkit = getServerToolkit(id);
const tools = await toolkit.tools(parameters);
return Object.keys(tools).reduce(
Expand Down Expand Up @@ -215,7 +232,12 @@ export async function POST(request: Request) {

// Collect toolkit system prompts
const toolkitSystemPrompts = await Promise.all(
toolkits.map(async ({ id }) => {
toolkits.map(async ({ id, parameters }) => {
if (id === Toolkits.MCP) {
const { url } = mcpParameters.parse(parameters);
return `You have access to a user-configured hosted MCP server at ${url}. Its tools are discovered dynamically and are prefixed with "${Toolkits.MCP}_". Use them when they match the user's request.`;
}

const toolkit = getServerToolkit(id);
return toolkit.systemPrompt;
}),
Expand Down Expand Up @@ -256,6 +278,7 @@ export async function POST(request: Request) {
experimental_generateMessageId: generateUUID,
onError: (error) => {
console.error("Stream error occurred:", error);
void closeMcpClients();

// Check if it's a 402 error and log it specifically
if (error && typeof error === "object") {
Expand All @@ -279,71 +302,75 @@ export async function POST(request: Request) {
// Don't throw - just let the stream end naturally after sending error data
},
onFinish: async ({ response }) => {
// Get the actual model used from OpenRouter's response
const [provider, modelId] = response.modelId.split("/");

// Try to find the model in our list first
const model = languageModels.find(
(model) =>
model.provider === provider && model.modelId === modelId,
);

// Create model info from OpenRouter's response if not in our list
const modelInfo = model ?? {
name: `${provider}/${modelId}`, // Format nicely for display
provider: provider ?? "unknown",
modelId: modelId ?? "unknown",
};

// Write the model annotation
dataStream.writeMessageAnnotation({
type: "model",
model: modelInfo,
});
try {
// Get the actual model used from OpenRouter's response
const [provider, modelId] = response.modelId.split("/");

// Try to find the model in our list first
const model = languageModels.find(
(model) =>
model.provider === provider && model.modelId === modelId,
);

// Create model info from OpenRouter's response if not in our list
const modelInfo = model ?? {
name: `${provider}/${modelId}`, // Format nicely for display
provider: provider ?? "unknown",
modelId: modelId ?? "unknown",
};

// Write the model annotation
dataStream.writeMessageAnnotation({
type: "model",
model: modelInfo,
});

// Send modelId as message annotation
if (session.user?.id) {
try {
const assistantId = getTrailingMessageId({
messages: response.messages.filter(
(message) => message.role === "assistant",
),
});

// Send modelId as message annotation
if (session.user?.id) {
try {
const assistantId = getTrailingMessageId({
messages: response.messages.filter(
(message) => message.role === "assistant",
),
});

if (!assistantId) {
throw new Error("No assistant message found!");
}
if (!assistantId) {
throw new Error("No assistant message found!");
}

const [, assistantMessage] = appendResponseMessages({
messages: [message],
responseMessages: response.messages,
});
const [, assistantMessage] = appendResponseMessages({
messages: [message],
responseMessages: response.messages,
});

if (!assistantMessage) {
throw new Error("No assistant message found!");
if (!assistantMessage) {
throw new Error("No assistant message found!");
}

await api.messages.createMessage({
chatId: id,
id: assistantId,
role: "assistant",
parts: assistantMessage.parts ?? [],
attachments:
assistantMessage.experimental_attachments?.map(
(attachment) => ({
url: attachment.url,
name: attachment.name ?? "",
contentType: attachment.contentType as
| "image/png"
| "image/jpg"
| "image/jpeg",
}),
) ?? [],
modelId: response.modelId, // Use the actual model from OpenRouter's response
});
} catch (error) {
console.error(error);
}

await api.messages.createMessage({
chatId: id,
id: assistantId,
role: "assistant",
parts: assistantMessage.parts ?? [],
attachments:
assistantMessage.experimental_attachments?.map(
(attachment) => ({
url: attachment.url,
name: attachment.name ?? "",
contentType: attachment.contentType as
| "image/png"
| "image/jpg"
| "image/jpeg",
}),
) ?? [],
modelId: response.modelId, // Use the actual model from OpenRouter's response
});
} catch (error) {
console.error(error);
}
} finally {
await closeMcpClients();
}
},
tools: {
Expand Down Expand Up @@ -377,6 +404,8 @@ export async function POST(request: Request) {
return new Response(stream);
}
} catch (error) {
await closeMcpClients();

if (error instanceof ChatSDKError) {
return error.toResponse();
}
Expand All @@ -385,6 +414,79 @@ export async function POST(request: Request) {
}
}

async function createMcpToolkitTools(
parameters: Record<string, unknown>,
mcpClients: MCPClient[],
): Promise<Record<string, Tool>> {
const { url, headers } = mcpParameters.parse(parameters);
const mcpClient = await createMCPClient({
transport: {
type: "sse",
url,
headers: parseMcpHeaders(headers),
},
});
mcpClients.push(mcpClient);

const mcpTools = await mcpClient.tools();

return Object.entries(mcpTools).reduce(
(acc, [toolName, mcpTool]) => {
acc[`${Toolkits.MCP}_${toolName}`] = {
...mcpTool,
execute: async (args, options) => {
const result = await mcpTool.execute(args, options);

return {
result,
message: getMcpToolResultMessage(result),
};
},
};
return acc;
},
{} as Record<string, Tool>,
);
}

function getMcpToolResultMessage(result: unknown) {
if (hasMcpContent(result)) {
const text = result.content
.map((content) => (isMcpTextContent(content) ? content.text : null))
.filter((content): content is string => content !== null)
.join("\n")
.trim();

if (text.length > 0) {
return text;
}
}

return "The MCP tool returned structured output. Use it to answer the user's request.";
}

function hasMcpContent(result: unknown): result is { content: unknown[] } {
return (
typeof result === "object" &&
result !== null &&
"content" in result &&
Array.isArray(result.content)
);
}

function isMcpTextContent(
content: unknown,
): content is { type: "text"; text: string } {
return (
typeof content === "object" &&
content !== null &&
"type" in content &&
content.type === "text" &&
"text" in content &&
typeof content.text === "string"
);
}

async function generateTitleFromUserMessage(message: UIMessage) {
const { text: title } = await generateText("openai/gpt-4o-mini", {
system: `\n
Expand Down
Loading