diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index c821e5a..94b394b 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -1,13 +1,12 @@ - on: pull_request: {} workflow_dispatch: {} - push: + push: branches: - main - master schedule: - - cron: '0 0 * * *' + - cron: "0 0 * * *" name: Semgrep config jobs: semgrep: diff --git a/bun.lock b/bun.lock new file mode 100644 index 0000000..5120817 --- /dev/null +++ b/bun.lock @@ -0,0 +1,84 @@ +{ + "lockfileVersion": 1, + "configVersion": 1, + "workspaces": { + "": { + "name": "@cloudflare/ai-utils", + "devDependencies": { + "@cloudflare/workers-types": "^4.20260415.1", + "@types/json-schema": "^7.0.15", + "@types/node": "^20.14.8", + "esbuild": "^0.21.5", + "prettier": "^3.3.2", + "typescript": "^5.5.2", + "yaml": "^2.4.5", + "zod": "^3.23.8", + }, + }, + }, + "packages": { + "@cloudflare/workers-types": ["@cloudflare/workers-types@4.20260415.1", "", {}, "sha512-9sEq9cZzr4s075U/TfjvdSmiX+u2NMOAIcFcCfd24FDtPfR7Iw3SbuQxkcgtpx/Bvg0au9PmQ0ZJfBaIitG0gw=="], + + "@esbuild/aix-ppc64": ["@esbuild/aix-ppc64@0.21.5", "", { "os": "aix", "cpu": "ppc64" }, "sha512-1SDgH6ZSPTlggy1yI6+Dbkiz8xzpHJEVAlF/AM1tHPLsf5STom9rwtjE4hKAF20FfXXNTFqEYXyJNWh1GiZedQ=="], + + "@esbuild/android-arm": ["@esbuild/android-arm@0.21.5", "", { "os": "android", "cpu": "arm" }, "sha512-vCPvzSjpPHEi1siZdlvAlsPxXl7WbOVUBBAowWug4rJHb68Ox8KualB+1ocNvT5fjv6wpkX6o/iEpbDrf68zcg=="], + + "@esbuild/android-arm64": ["@esbuild/android-arm64@0.21.5", "", { "os": "android", "cpu": "arm64" }, "sha512-c0uX9VAUBQ7dTDCjq+wdyGLowMdtR/GoC2U5IYk/7D1H1JYC0qseD7+11iMP2mRLN9RcCMRcjC4YMclCzGwS/A=="], + + "@esbuild/android-x64": ["@esbuild/android-x64@0.21.5", "", { "os": "android", "cpu": "x64" }, "sha512-D7aPRUUNHRBwHxzxRvp856rjUHRFW1SdQATKXH2hqA0kAZb1hKmi02OpYRacl0TxIGz/ZmXWlbZgjwWYaCakTA=="], + + "@esbuild/darwin-arm64": ["@esbuild/darwin-arm64@0.21.5", "", { "os": "darwin", "cpu": "arm64" }, "sha512-DwqXqZyuk5AiWWf3UfLiRDJ5EDd49zg6O9wclZ7kUMv2WRFr4HKjXp/5t8JZ11QbQfUS6/cRCKGwYhtNAY88kQ=="], + + "@esbuild/darwin-x64": ["@esbuild/darwin-x64@0.21.5", "", { "os": "darwin", "cpu": "x64" }, "sha512-se/JjF8NlmKVG4kNIuyWMV/22ZaerB+qaSi5MdrXtd6R08kvs2qCN4C09miupktDitvh8jRFflwGFBQcxZRjbw=="], + + "@esbuild/freebsd-arm64": ["@esbuild/freebsd-arm64@0.21.5", "", { "os": "freebsd", "cpu": "arm64" }, "sha512-5JcRxxRDUJLX8JXp/wcBCy3pENnCgBR9bN6JsY4OmhfUtIHe3ZW0mawA7+RDAcMLrMIZaf03NlQiX9DGyB8h4g=="], + + "@esbuild/freebsd-x64": ["@esbuild/freebsd-x64@0.21.5", "", { "os": "freebsd", "cpu": "x64" }, "sha512-J95kNBj1zkbMXtHVH29bBriQygMXqoVQOQYA+ISs0/2l3T9/kj42ow2mpqerRBxDJnmkUDCaQT/dfNXWX/ZZCQ=="], + + "@esbuild/linux-arm": ["@esbuild/linux-arm@0.21.5", "", { "os": "linux", "cpu": "arm" }, "sha512-bPb5AHZtbeNGjCKVZ9UGqGwo8EUu4cLq68E95A53KlxAPRmUyYv2D6F0uUI65XisGOL1hBP5mTronbgo+0bFcA=="], + + "@esbuild/linux-arm64": ["@esbuild/linux-arm64@0.21.5", "", { "os": "linux", "cpu": "arm64" }, "sha512-ibKvmyYzKsBeX8d8I7MH/TMfWDXBF3db4qM6sy+7re0YXya+K1cem3on9XgdT2EQGMu4hQyZhan7TeQ8XkGp4Q=="], + + "@esbuild/linux-ia32": ["@esbuild/linux-ia32@0.21.5", "", { "os": "linux", "cpu": "ia32" }, "sha512-YvjXDqLRqPDl2dvRODYmmhz4rPeVKYvppfGYKSNGdyZkA01046pLWyRKKI3ax8fbJoK5QbxblURkwK/MWY18Tg=="], + + "@esbuild/linux-loong64": ["@esbuild/linux-loong64@0.21.5", "", { "os": "linux", "cpu": "none" }, "sha512-uHf1BmMG8qEvzdrzAqg2SIG/02+4/DHB6a9Kbya0XDvwDEKCoC8ZRWI5JJvNdUjtciBGFQ5PuBlpEOXQj+JQSg=="], + + "@esbuild/linux-mips64el": ["@esbuild/linux-mips64el@0.21.5", "", { "os": "linux", "cpu": "none" }, "sha512-IajOmO+KJK23bj52dFSNCMsz1QP1DqM6cwLUv3W1QwyxkyIWecfafnI555fvSGqEKwjMXVLokcV5ygHW5b3Jbg=="], + + "@esbuild/linux-ppc64": ["@esbuild/linux-ppc64@0.21.5", "", { "os": "linux", "cpu": "ppc64" }, "sha512-1hHV/Z4OEfMwpLO8rp7CvlhBDnjsC3CttJXIhBi+5Aj5r+MBvy4egg7wCbe//hSsT+RvDAG7s81tAvpL2XAE4w=="], + + "@esbuild/linux-riscv64": ["@esbuild/linux-riscv64@0.21.5", "", { "os": "linux", "cpu": "none" }, "sha512-2HdXDMd9GMgTGrPWnJzP2ALSokE/0O5HhTUvWIbD3YdjME8JwvSCnNGBnTThKGEB91OZhzrJ4qIIxk/SBmyDDA=="], + + "@esbuild/linux-s390x": ["@esbuild/linux-s390x@0.21.5", "", { "os": "linux", "cpu": "s390x" }, "sha512-zus5sxzqBJD3eXxwvjN1yQkRepANgxE9lgOW2qLnmr8ikMTphkjgXu1HR01K4FJg8h1kEEDAqDcZQtbrRnB41A=="], + + "@esbuild/linux-x64": ["@esbuild/linux-x64@0.21.5", "", { "os": "linux", "cpu": "x64" }, "sha512-1rYdTpyv03iycF1+BhzrzQJCdOuAOtaqHTWJZCWvijKD2N5Xu0TtVC8/+1faWqcP9iBCWOmjmhoH94dH82BxPQ=="], + + "@esbuild/netbsd-x64": ["@esbuild/netbsd-x64@0.21.5", "", { "os": "none", "cpu": "x64" }, "sha512-Woi2MXzXjMULccIwMnLciyZH4nCIMpWQAs049KEeMvOcNADVxo0UBIQPfSmxB3CWKedngg7sWZdLvLczpe0tLg=="], + + "@esbuild/openbsd-x64": ["@esbuild/openbsd-x64@0.21.5", "", { "os": "openbsd", "cpu": "x64" }, "sha512-HLNNw99xsvx12lFBUwoT8EVCsSvRNDVxNpjZ7bPn947b8gJPzeHWyNVhFsaerc0n3TsbOINvRP2byTZ5LKezow=="], + + "@esbuild/sunos-x64": ["@esbuild/sunos-x64@0.21.5", "", { "os": "sunos", "cpu": "x64" }, "sha512-6+gjmFpfy0BHU5Tpptkuh8+uw3mnrvgs+dSPQXQOv3ekbordwnzTVEb4qnIvQcYXq6gzkyTnoZ9dZG+D4garKg=="], + + "@esbuild/win32-arm64": ["@esbuild/win32-arm64@0.21.5", "", { "os": "win32", "cpu": "arm64" }, "sha512-Z0gOTd75VvXqyq7nsl93zwahcTROgqvuAcYDUr+vOv8uHhNSKROyU961kgtCD1e95IqPKSQKH7tBTslnS3tA8A=="], + + "@esbuild/win32-ia32": ["@esbuild/win32-ia32@0.21.5", "", { "os": "win32", "cpu": "ia32" }, "sha512-SWXFF1CL2RVNMaVs+BBClwtfZSvDgtL//G/smwAc5oVK/UPu2Gu9tIaRgFmYFFKrmg3SyAjSrElf0TiJ1v8fYA=="], + + "@esbuild/win32-x64": ["@esbuild/win32-x64@0.21.5", "", { "os": "win32", "cpu": "x64" }, "sha512-tQd/1efJuzPC6rCFwEvLtci/xNFcTZknmXs98FYDfGE4wP9ClFV98nyKrzJKVPMhdDnjzLhdUyMX4PsQAPjwIw=="], + + "@types/json-schema": ["@types/json-schema@7.0.15", "", {}, "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA=="], + + "@types/node": ["@types/node@20.19.39", "", { "dependencies": { "undici-types": "~6.21.0" } }, "sha512-orrrD74MBUyK8jOAD/r0+lfa1I2MO6I+vAkmAWzMYbCcgrN4lCrmK52gRFQq/JRxfYPfonkr4b0jcY7Olqdqbw=="], + + "esbuild": ["esbuild@0.21.5", "", { "optionalDependencies": { "@esbuild/aix-ppc64": "0.21.5", "@esbuild/android-arm": "0.21.5", "@esbuild/android-arm64": "0.21.5", "@esbuild/android-x64": "0.21.5", "@esbuild/darwin-arm64": "0.21.5", "@esbuild/darwin-x64": "0.21.5", "@esbuild/freebsd-arm64": "0.21.5", "@esbuild/freebsd-x64": "0.21.5", "@esbuild/linux-arm": "0.21.5", "@esbuild/linux-arm64": "0.21.5", "@esbuild/linux-ia32": "0.21.5", "@esbuild/linux-loong64": "0.21.5", "@esbuild/linux-mips64el": "0.21.5", "@esbuild/linux-ppc64": "0.21.5", "@esbuild/linux-riscv64": "0.21.5", "@esbuild/linux-s390x": "0.21.5", "@esbuild/linux-x64": "0.21.5", "@esbuild/netbsd-x64": "0.21.5", "@esbuild/openbsd-x64": "0.21.5", "@esbuild/sunos-x64": "0.21.5", "@esbuild/win32-arm64": "0.21.5", "@esbuild/win32-ia32": "0.21.5", "@esbuild/win32-x64": "0.21.5" }, "bin": { "esbuild": "bin/esbuild" } }, "sha512-mg3OPMV4hXywwpoDxu3Qda5xCKQi+vCTZq8S9J/EpkhB2HzKXq4SNFZE3+NK93JYxc8VMSep+lOUSC/RVKaBqw=="], + + "prettier": ["prettier@3.8.3", "", { "bin": { "prettier": "bin/prettier.cjs" } }, "sha512-7igPTM53cGHMW8xWuVTydi2KO233VFiTNyF5hLJqpilHfmn8C8gPf+PS7dUT64YcXFbiMGZxS9pCSxL/Dxm/Jw=="], + + "typescript": ["typescript@5.9.3", "", { "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" } }, "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw=="], + + "undici-types": ["undici-types@6.21.0", "", {}, "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ=="], + + "yaml": ["yaml@2.8.3", "", { "bin": { "yaml": "bin.mjs" } }, "sha512-AvbaCLOO2Otw/lW5bmh9d/WEdcDFdQp2Z2ZUH3pX9U2ihyUY0nvLv7J6TrWowklRGPYbB/IuIMfYgxaCPg5Bpg=="], + + "zod": ["zod@3.25.76", "", {}, "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ=="], + } +} diff --git a/package.json b/package.json index d16d414..745de1e 100644 --- a/package.json +++ b/package.json @@ -16,7 +16,7 @@ "types": "dist/index.d.ts", "dependencies": {}, "devDependencies": { - "@cloudflare/workers-types": "^4.20240620.0", + "@cloudflare/workers-types": "^4.20260415.1", "@types/json-schema": "^7.0.15", "@types/node": "^20.14.8", "esbuild": "^0.21.5", diff --git a/src/runWithTools.ts b/src/runWithTools.ts index ff38249..64d0649 100644 --- a/src/runWithTools.ts +++ b/src/runWithTools.ts @@ -2,12 +2,81 @@ import { Logger } from "./logger"; import { validateArgsWithZod } from "./utils"; import { Ai, - AiTextGenerationInput, + AiModels, AiTextGenerationOutput, - BaseAiTextGenerationModels, + ChatCompletionsCommonOptions, RoleScopedChatInput, } from "@cloudflare/workers-types"; -import { AiTextGenerationToolInputWithFunction } from "./types"; +import { + AiTextGenerationToolInputWithFunction, + BaseAiTextGenerationModels, +} from "./types"; + +/** + * Extracts advanced chat input configuration (excluding messages and tools) + * from the model's input type, supporting both legacy AiTextGenerationInput + * and newer ChatCompletionsMessagesInput formats. + */ +type ChatInputConfig = Omit< + AiModels[Model]["inputs"], + "messages" | "tools" +>; + +type NormalizedToolCall = { + name: string; + arguments: unknown; +}; +/** + * Extracts tool calls from either Mistral-like or OpenAI-like response formats. + */ +function extractToolCalls(response: unknown): NormalizedToolCall[] { + if (!response || typeof response !== "object") { + return []; + } + + const res = response as Record; + + // Mistral-like format: { tool_calls: [{ name, arguments }] } + if (Array.isArray(res.tool_calls)) { + return res.tool_calls + .filter(Boolean) + .map((tc: { name: string; arguments: unknown }) => ({ + name: tc.name, + arguments: tc.arguments, + })); + } + + // OpenAI-like format: { choices: [{ message: { tool_calls: [{ function: { name, arguments } }] } }] } + if (Array.isArray(res.choices)) { + const choices = res.choices as Array<{ + message?: { + tool_calls?: Array<{ + function?: { name: string; arguments: string }; + }>; + }; + }>; + for (const choice of choices) { + if (Array.isArray(choice.message?.tool_calls)) { + return choice.message.tool_calls.filter(Boolean).map((tc) => { + let args: unknown = tc.function?.arguments; + if (typeof args === "string") { + try { + args = JSON.parse(args); + } catch { + // Keep as string if not valid JSON + } + } + return { + name: tc.function?.name ?? "", + arguments: args, + }; + }); + } + } + } + + return []; +} /** * Runs a set of tools on a given input and returns the final response in the same format as the AI.run call. @@ -23,14 +92,15 @@ import { AiTextGenerationToolInputWithFunction } from "./types"; * @param {boolean} [config.strictValidation=false] - Whether to perform strict validation (using zod) of the arguments passed to the tools. * @param {boolean} [config.verbose=false] - Whether to enable verbose logging. * @param {(tools: AiTextGenerationToolInputWithFunction[], ai: Ai, model: BaseAiTextGenerationModels, messages: RoleScopedChatInput[]) => Promise} [config.trimFunction] - Use a trim function to trim down the number of tools given to the AI for a given task. You can also use this alongside `autoTrimTools`, which uses an extra AI.run call to cut down on the input tokens of the tool call based on the tool's names. + * @param {ChatCompletionsCommonOptions} [chatInputConfig] - Advanced AI.run configuration options (e.g. temperature, max_tokens, top_p, seed, response_format). Inferred from the model's input type. * - * @returns {Promise} The final response in the same format as the AI.run call. + * @returns {Promise} The final response in the same format as the AI.run call, with the output type inferred from the provided model. */ -export const runWithTools = async ( +export const runWithTools = async ( /** The AI instance to use for the run. */ ai: Ai, /** The function calling model to use for the run. We recommend using `@hf/nousresearch/hermes-2-pro-mistral-7b`, `llama-3` or equivalent model that's suited for function calling. */ - model: BaseAiTextGenerationModels, + model: Model, /** The input for the runWithTools call. */ input: { /** The messages to be sent to the AI. */ @@ -57,7 +127,9 @@ export const runWithTools = async ( messages: RoleScopedChatInput[], ) => Promise; } = {}, -): Promise => { + /** Advanced AI.run configuration options (e.g. temperature, max_tokens, top_p, seed, response_format). Inferred from the model's input type. */ + chatInputConfig: ChatInputConfig = {} as ChatInputConfig, +): Promise => { // Destructure config with default values const { streamFinalResponse = false, @@ -112,28 +184,23 @@ export const runWithTools = async ( maxRecursiveToolRuns, }: { ai: Ai; - model: BaseAiTextGenerationModels; + model: Model; messages: RoleScopedChatInput[]; streamFinalResponse: boolean; maxRecursiveToolRuns: number; - }): Promise { + }): Promise { try { Logger.info("Starting AI.run call"); Logger.info("Messages", JSON.stringify(messages, null, 2)); Logger.info(`Only using ${input.tools.length} tools`); - const response = (await ai.run(model, { - messages: messages, + const rawResponse = await ai.run(model, { + messages, stream: false, - tools: tools, - })) as { - response?: string; - tool_calls?: { - name: string; - arguments: unknown; - }[]; - }; + ...(chatInputConfig as ChatCompletionsCommonOptions), + tools, + }); const chars = JSON.stringify(messages).length + @@ -143,9 +210,10 @@ export const runWithTools = async ( `Number of characters for the first AI.run call: ${totalCharacters}`, ); - Logger.info("AI.run call completed", response); + Logger.info("AI.run call completed", rawResponse); - tool_calls = response.tool_calls?.filter(Boolean) ?? []; + // Extract tool_calls from either Mistral-like or OpenAI-like response format + tool_calls = extractToolCalls(rawResponse); const toolCallPromises = tool_calls.map(async (toolCall) => { const toolCallObjectJson = toolCall; @@ -197,7 +265,6 @@ export const runWithTools = async ( messages.push({ role: "tool", content: JSON.stringify(result), - // @ts-expect-error workerd types name: selectedTool.name, }); } catch (error) { @@ -205,7 +272,6 @@ export const runWithTools = async ( messages.push({ role: "tool", content: `Error executing tool ${selectedTool.name}: ${(error as Error).message}`, - // @ts-expect-error workerd types name: selectedTool.name, }); } @@ -213,7 +279,7 @@ export const runWithTools = async ( Logger.error( `Function for tool ${toolCallObjectJson.name} is undefined`, ); - return response + return rawResponse as AiTextGenerationOutput; } }); @@ -235,8 +301,9 @@ export const runWithTools = async ( ); const finalResponse = await ai.run(model, { - messages: messages, + messages, stream: streamFinalResponse, + ...(chatInputConfig as ChatCompletionsCommonOptions), }); totalCharacters += JSON.stringify(messages).length; Logger.info( diff --git a/src/types/index.ts b/src/types/index.ts index fd6ddcf..8ca1ebe 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -1,4 +1,4 @@ -import { AiTextGenerationToolInput } from "@cloudflare/workers-types"; +import { AiModels, AiTextGenerationToolInput } from "@cloudflare/workers-types"; import { JSONSchema7 } from "json-schema"; export type UppercaseHttpMethod = @@ -91,3 +91,5 @@ export function tool( ): ToolsSchema { return tool; } + +export type BaseAiTextGenerationModels = keyof AiModels; diff --git a/src/utils.ts b/src/utils.ts index 789973b..94090d4 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -3,12 +3,11 @@ import { OpenAPIV3 } from "./types/openapi-schema"; import { JSONSchema7 } from "json-schema"; import { ZodTypeAny, z } from "zod"; import { Logger } from "./logger"; -import { AiTextGenerationToolInputWithFunction } from "./types"; import { - Ai, + AiTextGenerationToolInputWithFunction, BaseAiTextGenerationModels, - RoleScopedChatInput, -} from "@cloudflare/workers-types"; +} from "./types"; +import { Ai, RoleScopedChatInput } from "@cloudflare/workers-types"; export async function fetchSpec( spec: string,