Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/semgrep.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@

on:
pull_request: {}
workflow_dispatch: {}
push:
push:
branches:
- main
- master
schedule:
- cron: '0 0 * * *'
- cron: "0 0 * * *"
name: Semgrep config
jobs:
semgrep:
Expand Down
84 changes: 84 additions & 0 deletions bun.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"types": "dist/index.d.ts",
"dependencies": {},
"devDependencies": {
"@cloudflare/workers-types": "^4.20240620.0",
"@cloudflare/workers-types": "^4.20260415.1",
"@types/json-schema": "^7.0.15",
"@types/node": "^20.14.8",
"esbuild": "^0.21.5",
Expand Down
117 changes: 92 additions & 25 deletions src/runWithTools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,81 @@ import { Logger } from "./logger";
import { validateArgsWithZod } from "./utils";
import {
Ai,
AiTextGenerationInput,
AiModels,
AiTextGenerationOutput,
BaseAiTextGenerationModels,
ChatCompletionsCommonOptions,
RoleScopedChatInput,
} from "@cloudflare/workers-types";
import { AiTextGenerationToolInputWithFunction } from "./types";
import {
AiTextGenerationToolInputWithFunction,
BaseAiTextGenerationModels,
} from "./types";

/**
* Extracts advanced chat input configuration (excluding messages and tools)
* from the model's input type, supporting both legacy AiTextGenerationInput
* and newer ChatCompletionsMessagesInput formats.
*/
type ChatInputConfig<Model extends keyof AiModels> = Omit<
AiModels[Model]["inputs"],
"messages" | "tools"
>;

type NormalizedToolCall = {
name: string;
arguments: unknown;
};
/**
* Extracts tool calls from either Mistral-like or OpenAI-like response formats.
*/
function extractToolCalls(response: unknown): NormalizedToolCall[] {
if (!response || typeof response !== "object") {
return [];
}

const res = response as Record<string, unknown>;

// Mistral-like format: { tool_calls: [{ name, arguments }] }
if (Array.isArray(res.tool_calls)) {
return res.tool_calls
.filter(Boolean)
.map((tc: { name: string; arguments: unknown }) => ({
name: tc.name,
arguments: tc.arguments,
}));
}

// OpenAI-like format: { choices: [{ message: { tool_calls: [{ function: { name, arguments } }] } }] }
if (Array.isArray(res.choices)) {
const choices = res.choices as Array<{
message?: {
tool_calls?: Array<{
function?: { name: string; arguments: string };
}>;
};
}>;
for (const choice of choices) {
if (Array.isArray(choice.message?.tool_calls)) {
return choice.message.tool_calls.filter(Boolean).map((tc) => {
let args: unknown = tc.function?.arguments;
if (typeof args === "string") {
try {
args = JSON.parse(args);
} catch {
// Keep as string if not valid JSON
}
}
return {
name: tc.function?.name ?? "",
arguments: args,
};
});
}
}
}

return [];
}

/**
* Runs a set of tools on a given input and returns the final response in the same format as the AI.run call.
Expand All @@ -23,14 +92,15 @@ import { AiTextGenerationToolInputWithFunction } from "./types";
* @param {boolean} [config.strictValidation=false] - Whether to perform strict validation (using zod) of the arguments passed to the tools.
* @param {boolean} [config.verbose=false] - Whether to enable verbose logging.
* @param {(tools: AiTextGenerationToolInputWithFunction[], ai: Ai, model: BaseAiTextGenerationModels, messages: RoleScopedChatInput[]) => Promise<AiTextGenerationToolInputWithFunction[]>} [config.trimFunction] - Use a trim function to trim down the number of tools given to the AI for a given task. You can also use this alongside `autoTrimTools`, which uses an extra AI.run call to cut down on the input tokens of the tool call based on the tool's names.
* @param {ChatCompletionsCommonOptions} [chatInputConfig] - Advanced AI.run configuration options (e.g. temperature, max_tokens, top_p, seed, response_format). Inferred from the model's input type.
*
* @returns {Promise<AiTextGenerationOutput>} The final response in the same format as the AI.run call.
* @returns {Promise<AiModels[Model]["postProcessedOutputs"]>} The final response in the same format as the AI.run call, with the output type inferred from the provided model.
*/
export const runWithTools = async (
export const runWithTools = async <Model extends BaseAiTextGenerationModels>(
/** The AI instance to use for the run. */
ai: Ai,
/** The function calling model to use for the run. We recommend using `@hf/nousresearch/hermes-2-pro-mistral-7b`, `llama-3` or equivalent model that's suited for function calling. */
model: BaseAiTextGenerationModels,
model: Model,
/** The input for the runWithTools call. */
input: {
/** The messages to be sent to the AI. */
Expand All @@ -57,7 +127,9 @@ export const runWithTools = async (
messages: RoleScopedChatInput[],
) => Promise<AiTextGenerationToolInputWithFunction[]>;
} = {},
): Promise<AiTextGenerationOutput> => {
/** Advanced AI.run configuration options (e.g. temperature, max_tokens, top_p, seed, response_format). Inferred from the model's input type. */
chatInputConfig: ChatInputConfig<Model> = {} as ChatInputConfig<Model>,
): Promise<AiModels[Model]["postProcessedOutputs"]> => {
// Destructure config with default values
const {
streamFinalResponse = false,
Expand Down Expand Up @@ -112,28 +184,23 @@ export const runWithTools = async (
maxRecursiveToolRuns,
}: {
ai: Ai;
model: BaseAiTextGenerationModels;
model: Model;
messages: RoleScopedChatInput[];
streamFinalResponse: boolean;
maxRecursiveToolRuns: number;
}): Promise<AiTextGenerationOutput> {
}): Promise<AiModels[Model]["postProcessedOutputs"]> {
try {
Logger.info("Starting AI.run call");
Logger.info("Messages", JSON.stringify(messages, null, 2));

Logger.info(`Only using ${input.tools.length} tools`);

const response = (await ai.run(model, {
messages: messages,
const rawResponse = await ai.run(model, {
messages,
stream: false,
tools: tools,
})) as {
response?: string;
tool_calls?: {
name: string;
arguments: unknown;
}[];
};
...(chatInputConfig as ChatCompletionsCommonOptions),
tools,
});

const chars =
JSON.stringify(messages).length +
Expand All @@ -143,9 +210,10 @@ export const runWithTools = async (
`Number of characters for the first AI.run call: ${totalCharacters}`,
);

Logger.info("AI.run call completed", response);
Logger.info("AI.run call completed", rawResponse);

tool_calls = response.tool_calls?.filter(Boolean) ?? [];
// Extract tool_calls from either Mistral-like or OpenAI-like response format
tool_calls = extractToolCalls(rawResponse);

const toolCallPromises = tool_calls.map(async (toolCall) => {
const toolCallObjectJson = toolCall;
Expand Down Expand Up @@ -197,23 +265,21 @@ export const runWithTools = async (
messages.push({
role: "tool",
content: JSON.stringify(result),
// @ts-expect-error workerd types
name: selectedTool.name,
});
} catch (error) {
Logger.error(`Error executing tool ${selectedTool.name}:`, error);
messages.push({
role: "tool",
content: `Error executing tool ${selectedTool.name}: ${(error as Error).message}`,
// @ts-expect-error workerd types
name: selectedTool.name,
});
}
} else {
Logger.error(
`Function for tool ${toolCallObjectJson.name} is undefined`,
);
return response
return rawResponse as AiTextGenerationOutput;
}
});

Expand All @@ -235,8 +301,9 @@ export const runWithTools = async (
);

const finalResponse = await ai.run(model, {
messages: messages,
messages,
stream: streamFinalResponse,
...(chatInputConfig as ChatCompletionsCommonOptions),
});
totalCharacters += JSON.stringify(messages).length;
Logger.info(
Expand Down
4 changes: 3 additions & 1 deletion src/types/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AiTextGenerationToolInput } from "@cloudflare/workers-types";
import { AiModels, AiTextGenerationToolInput } from "@cloudflare/workers-types";
import { JSONSchema7 } from "json-schema";

export type UppercaseHttpMethod =
Expand Down Expand Up @@ -91,3 +91,5 @@ export function tool<T extends JSONSchema7>(
): ToolsSchema<T> {
return tool;
}

export type BaseAiTextGenerationModels = keyof AiModels;
7 changes: 3 additions & 4 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ import { OpenAPIV3 } from "./types/openapi-schema";
import { JSONSchema7 } from "json-schema";
import { ZodTypeAny, z } from "zod";
import { Logger } from "./logger";
import { AiTextGenerationToolInputWithFunction } from "./types";
import {
Ai,
AiTextGenerationToolInputWithFunction,
BaseAiTextGenerationModels,
RoleScopedChatInput,
} from "@cloudflare/workers-types";
} from "./types";
import { Ai, RoleScopedChatInput } from "@cloudflare/workers-types";

export async function fetchSpec(
spec: string,
Expand Down