diff --git a/package.json b/package.json index d16d414..bdebe66 100644 --- a/package.json +++ b/package.json @@ -14,9 +14,8 @@ ], "main": "dist/index.js", "types": "dist/index.d.ts", - "dependencies": {}, "devDependencies": { - "@cloudflare/workers-types": "^4.20240620.0", + "@cloudflare/workers-types": "^4.20251003.0", "@types/json-schema": "^7.0.15", "@types/node": "^20.14.8", "esbuild": "^0.21.5", diff --git a/src/runWithTools.ts b/src/runWithTools.ts index ff38249..31407b0 100644 --- a/src/runWithTools.ts +++ b/src/runWithTools.ts @@ -4,25 +4,25 @@ import { Ai, AiTextGenerationInput, AiTextGenerationOutput, - BaseAiTextGenerationModels, RoleScopedChatInput, } from "@cloudflare/workers-types"; -import { AiTextGenerationToolInputWithFunction } from "./types"; +import { AiTextGenerationToolInputWithFunction, ModelName } from "./types"; /** * Runs a set of tools on a given input and returns the final response in the same format as the AI.run call. * * @param {Ai} ai - The AI instance to use for the run. - * @param {BaseAiTextGenerationModels} model - The function calling model to use for the run. We recommend using `@hf/nousresearch/hermes-2-pro-mistral-7b`, `llama-3` or equivalent model that's suited for function calling. + * @param {ModelName} model - The function calling model to use for the run. We recommend using `@hf/nousresearch/hermes-2-pro-mistral-7b`, `llama-3` or equivalent model that's suited for function calling. * @param {Object} input - The input for the runWithTools call. * @param {RoleScopedChatInput[]} input.messages - The messages to be sent to the AI. * @param {AiTextGenerationToolInputWithFunction[]} input.tools - The tools to be used. You can also pass a function along with each tool that will automatically run the tool with the arguments passed to the function. The function arguments are type-checked against your tool's parameters, so you can get autocomplete and type checking in your IDE. + * @param {number} [input.max_tokens] - Maximum number of tokens to generate in the response. * @param {Object} config - Configuration options for the runWithTools call. * @param {boolean} [config.streamFinalResponse=false] - Whether to stream the final response or not. * @param {number} [config.maxRecursiveToolRuns=0] - The maximum number of recursive tool runs to perform. * @param {boolean} [config.strictValidation=false] - Whether to perform strict validation (using zod) of the arguments passed to the tools. * @param {boolean} [config.verbose=false] - Whether to enable verbose logging. - * @param {(tools: AiTextGenerationToolInputWithFunction[], ai: Ai, model: BaseAiTextGenerationModels, messages: RoleScopedChatInput[]) => Promise} [config.trimFunction] - Use a trim function to trim down the number of tools given to the AI for a given task. You can also use this alongside `autoTrimTools`, which uses an extra AI.run call to cut down on the input tokens of the tool call based on the tool's names. + * @param {(tools: AiTextGenerationToolInputWithFunction[], ai: Ai, model: ModelName, messages: RoleScopedChatInput[]) => Promise} [config.trimFunction] - Use a trim function to trim down the number of tools given to the AI for a given task. You can also use this alongside `autoTrimTools`, which uses an extra AI.run call to cut down on the input tokens of the tool call based on the tool's names. * * @returns {Promise} The final response in the same format as the AI.run call. */ @@ -30,13 +30,15 @@ export const runWithTools = async ( /** The AI instance to use for the run. */ ai: Ai, /** The function calling model to use for the run. We recommend using `@hf/nousresearch/hermes-2-pro-mistral-7b`, `llama-3` or equivalent model that's suited for function calling. */ - model: BaseAiTextGenerationModels, + model: ModelName, /** The input for the runWithTools call. */ input: { /** The messages to be sent to the AI. */ messages: RoleScopedChatInput[]; /** The tools to be used. You can also pass a function along with each tool that will Automatically run the tool with the arguments passed to the function. The function arguments are type-checked against your tool's parameters, so you can get autocomplete and type checking in your IDE. */ tools: AiTextGenerationToolInputWithFunction[]; + /** Maximum number of tokens to generate in the response. */ + max_tokens?: number; }, /** Configuration options for the runWithTools call. */ config: { @@ -53,7 +55,7 @@ export const runWithTools = async ( trimFunction?: ( tools: AiTextGenerationToolInputWithFunction[], ai: Ai, - model: BaseAiTextGenerationModels, + model: ModelName, messages: RoleScopedChatInput[], ) => Promise; } = {}, @@ -66,7 +68,7 @@ export const runWithTools = async ( trimFunction = async ( tools: AiTextGenerationToolInputWithFunction[], ai: Ai, - model: BaseAiTextGenerationModels, + model: ModelName, messages: RoleScopedChatInput[], ) => tools as AiTextGenerationToolInputWithFunction[], strictValidation = false, @@ -112,7 +114,7 @@ export const runWithTools = async ( maxRecursiveToolRuns, }: { ai: Ai; - model: BaseAiTextGenerationModels; + model: ModelName; messages: RoleScopedChatInput[]; streamFinalResponse: boolean; maxRecursiveToolRuns: number; @@ -127,6 +129,7 @@ export const runWithTools = async ( messages: messages, stream: false, tools: tools, + ...(input.max_tokens !== undefined && { max_tokens: input.max_tokens }), })) as { response?: string; tool_calls?: { @@ -197,7 +200,6 @@ export const runWithTools = async ( messages.push({ role: "tool", content: JSON.stringify(result), - // @ts-expect-error workerd types name: selectedTool.name, }); } catch (error) { @@ -205,7 +207,6 @@ export const runWithTools = async ( messages.push({ role: "tool", content: `Error executing tool ${selectedTool.name}: ${(error as Error).message}`, - // @ts-expect-error workerd types name: selectedTool.name, }); } @@ -237,6 +238,7 @@ export const runWithTools = async ( const finalResponse = await ai.run(model, { messages: messages, stream: streamFinalResponse, + ...(input.max_tokens !== undefined && { max_tokens: input.max_tokens }), }); totalCharacters += JSON.stringify(messages).length; Logger.info( @@ -244,7 +246,7 @@ export const runWithTools = async ( ); Logger.info(`Total number of characters: ${totalCharacters}`); - return finalResponse; + return finalResponse as AiTextGenerationOutput; } } catch (error) { Logger.error("Error in runAndProcessToolCall:", error); diff --git a/src/types/index.ts b/src/types/index.ts index fd6ddcf..4141ca7 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -1,6 +1,11 @@ -import { AiTextGenerationToolInput } from "@cloudflare/workers-types"; +import { AiTextGenerationToolInput, AiModels } from "@cloudflare/workers-types"; import { JSONSchema7 } from "json-schema"; +/** + * Model names available in Workers AI + */ +export type ModelName = keyof AiModels; + export type UppercaseHttpMethod = | "GET" | "POST" diff --git a/src/utils.ts b/src/utils.ts index 789973b..eec00f4 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -3,10 +3,9 @@ import { OpenAPIV3 } from "./types/openapi-schema"; import { JSONSchema7 } from "json-schema"; import { ZodTypeAny, z } from "zod"; import { Logger } from "./logger"; -import { AiTextGenerationToolInputWithFunction } from "./types"; +import { AiTextGenerationToolInputWithFunction, ModelName } from "./types"; import { Ai, - BaseAiTextGenerationModels, RoleScopedChatInput, } from "@cloudflare/workers-types"; @@ -92,7 +91,7 @@ export function validateArgsWithZod( export async function autoTrimTools( tools: AiTextGenerationToolInputWithFunction[], ai: Ai, - model: BaseAiTextGenerationModels, + model: ModelName, messages: RoleScopedChatInput[], ) { let returnedTools = tools;