Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions packages/core/src/ai/__tests__/llm-provider.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import type { AIEndpoint } from "../../types";
import { getEndpointFetch } from "../llm-provider";

// Snapshot of the real global fetch, taken at module load so the tests can
// restore it after replacing `globalThis.fetch` with a stub.
const originalFetch = globalThis.fetch;

/**
 * Builds an `AIEndpoint` fixture for the tests in this file.
 *
 * Starts from a fixed "custom" provider endpoint that uses its `baseUrl` as the
 * exact request URL, then layers `overrides` on top so each test only has to
 * spell out the fields it cares about.
 *
 * @param overrides - Fields that replace the fixture defaults.
 * @returns The merged endpoint object.
 */
function makeEndpoint(overrides: Partial<AIEndpoint> = {}): AIEndpoint {
  const defaults: AIEndpoint = {
    id: "endpoint-1",
    name: "Test",
    provider: "custom",
    apiKey: "test-key",
    baseUrl: "https://api.example.com/v1/chat/completions",
    useExactRequestUrl: true,
    models: ["test-model"],
    modelsFetched: true,
  };

  // Overrides are spread last so they win over the defaults.
  return { ...defaults, ...overrides };
}

/**
 * Builds an OpenAI-style function tool call fixture.
 *
 * @param extraContent - When provided, attached to the tool call under the
 *   `extra_content` key; when omitted, the key is absent from the result.
 * @returns A tool call for the `getCurrentChapter` function with empty JSON arguments.
 */
function makeToolCall(extraContent?: Record<string, unknown>) {
  const toolCall: {
    id: string;
    type: string;
    function: { name: string; arguments: string };
    extra_content?: Record<string, unknown>;
  } = {
    id: "call_1",
    type: "function",
    function: {
      name: "getCurrentChapter",
      arguments: "{}",
    },
  };

  // Only add the key when there is something to attach, so a bare tool call
  // has no `extra_content` property at all.
  if (extraContent) {
    toolCall.extra_content = extraContent;
  }

  return toolCall;
}

/**
 * Pulls the first tool call out of the first message of a chat request body.
 *
 * The body is not validated: a missing message or missing `tool_calls` array
 * surfaces as a `TypeError`.
 *
 * @param body - Parsed request body expected to hold a `messages` array.
 * @returns The first entry of the first message's `tool_calls`.
 */
function getFirstToolCall(body: Record<string, unknown>): Record<string, unknown> {
  const [leadMessage] = body.messages as Array<Record<string, unknown>>;
  const [leadToolCall] = leadMessage.tool_calls as Array<Record<string, unknown>>;
  return leadToolCall;
}

/**
 * Sends `body` through the fetch wrapper returned by `getEndpointFetch` and
 * returns the JSON payload that reached the stubbed global fetch.
 *
 * `globalThis.fetch` is replaced with a stub that records the outgoing body and
 * answers with an empty JSON object. The stub is left installed; restoring the
 * real fetch is up to the caller.
 *
 * @param endpoint - Endpoint configuration handed to `getEndpointFetch`.
 * @param model - Model name handed to `getEndpointFetch`.
 * @param body - Request payload, serialized to JSON before sending.
 * @returns The parsed body observed by the stub.
 */
async function captureRequestBody(
  endpoint: AIEndpoint,
  model: string,
  body: Record<string, unknown>,
): Promise<Record<string, unknown>> {
  let observedPayload = "";

  const fetchStub = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => {
    const initBody = init?.body;
    if (typeof initBody === "string") {
      // Body supplied through the init object.
      observedPayload = initBody;
    } else if (input instanceof Request) {
      // Body carried by a Request instance; read a clone so the original stays unconsumed.
      observedPayload = await input.clone().text();
    }

    return new Response("{}", {
      status: 200,
      headers: { "content-type": "application/json" },
    });
  });
  globalThis.fetch = fetchStub as typeof fetch;

  // The stub must be installed before the wrapper is created and invoked.
  const wrappedFetch = getEndpointFetch(endpoint, model);
  const serializedBody = JSON.stringify(body);
  await wrappedFetch(endpoint.baseUrl, {
    method: "POST",
    headers: { "content-type": "application/json" },
    body: serializedBody,
  });

  return JSON.parse(observedPayload) as Record<string, unknown>;
}

// Undo the fetch stub that captureRequestBody installs, so each test starts
// from the fetch implementation that was present when this module loaded.
afterEach(() => {
globalThis.fetch = originalFetch;
});

// Exercises the request-body patching performed by the fetch wrapper returned
// from getEndpointFetch: each case sends an assistant message containing one
// tool call and inspects the body that reaches the stubbed fetch.
describe("getEndpointFetch Gemini thought signatures", () => {
// The endpoint keeps the fixture's "custom" provider and api.example.com URL,
// so the "gemini-3-flash-preview" model name is the only Gemini marker here.
it("adds a Gemini thought signature bypass for gemini-3 OpenAI-compatible tool calls", async () => {
const patchedBody = await captureRequestBody(
makeEndpoint({ models: ["gemini-3-flash-preview"] }),
"gemini-3-flash-preview",
{
messages: [
{
role: "assistant",
content: "",
tool_calls: [makeToolCall()],
},
],
},
);

const toolCall = getFirstToolCall(patchedBody);
const extraContent = toolCall.extra_content as Record<string, Record<string, string>>;
expect(extraContent.google.thought_signature).toBe("skip_thought_signature_validator");
});

// A tool call that already carries a signature must reach fetch with that
// signature unchanged, even on a Google endpoint.
it("preserves an existing Gemini thought signature", async () => {
const patchedBody = await captureRequestBody(
makeEndpoint({
provider: "google",
baseUrl: "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
}),
"gemini-3-flash-preview",
{
messages: [
{
role: "assistant",
content: "",
tool_calls: [
makeToolCall({
google: {
thought_signature: "real-signature",
},
}),
],
},
],
},
);

const toolCall = getFirstToolCall(patchedBody);
const extraContent = toolCall.extra_content as Record<string, Record<string, string>>;
expect(extraContent.google.thought_signature).toBe("real-signature");
});

// Neither the provider, the URL, nor the model name points at Gemini, so the
// tool call must come through without an extra_content field.
it("does not modify non-Gemini OpenAI-compatible requests", async () => {
const patchedBody = await captureRequestBody(
makeEndpoint({
baseUrl: "https://api.openai.com/v1/chat/completions",
models: ["gpt-4o-mini"],
}),
"gpt-4o-mini",
{
messages: [
{
role: "assistant",
content: "",
tool_calls: [makeToolCall()],
},
],
},
);

const toolCall = getFirstToolCall(patchedBody);
expect(toolCall.extra_content).toBeUndefined();
});
});
145 changes: 140 additions & 5 deletions packages/core/src/ai/llm-provider.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import type { BaseChatModel } from "@langchain/core/language_models/chat_models";
import type { AIConfig, AIEndpoint } from "../types";
import { logAIEndpointDebug, summarizeDebugText } from "./request-debug";
import { providerRequiresApiKey } from "../utils";
import { formatApiHost } from "../utils/api";
import { logAIEndpointDebug, summarizeDebugText } from "./request-debug";

/**
* Optional custom fetch for streaming support (e.g. expo/fetch in React Native).
Expand Down Expand Up @@ -66,6 +66,129 @@ function sanitizeCustomHeaders(headers?: Headers): Headers | undefined {
return sanitized;
}

/**
 * Placeholder written to `extra_content.google.thought_signature` on tool calls
 * that do not already carry a signature.
 * NOTE(review): presumably a value Gemini recognizes as "skip signature
 * validation" — confirm against the Gemini API documentation.
 */
const GEMINI_THOUGHT_SIGNATURE_BYPASS = "skip_thought_signature_validator";

/**
 * Decides whether an outgoing request should have Gemini thought signatures patched in.
 *
 * Any one of these is enough:
 * - the endpoint's provider is `"google"`;
 * - the request URL or the endpoint's base URL mentions
 *   `generativelanguage.googleapis.com` (case-insensitive);
 * - the model name starts with `gemini-3` (case-insensitive).
 *
 * @param endpoint - Endpoint configuration; `provider` and `baseUrl` are consulted.
 * @param model - Model name for the request, if known.
 * @param requestUrl - URL the request is being sent to.
 * @returns `true` when the request looks like it targets Gemini.
 */
function shouldPatchGeminiThoughtSignatures(
  endpoint: AIEndpoint,
  model: string | undefined,
  requestUrl: string,
): boolean {
  if (endpoint.provider === "google") {
    return true;
  }

  // Both URLs are searched together so a match in either one counts.
  const combinedUrls = `${requestUrl} ${endpoint.baseUrl ?? ""}`.toLowerCase();
  if (combinedUrls.includes("generativelanguage.googleapis.com")) {
    return true;
  }

  const normalizedModel = model?.toLowerCase() ?? "";
  return normalizedModel.startsWith("gemini-3");
}

/**
 * Reports whether a tool call already carries a usable Gemini thought
 * signature at `extra_content.google.thought_signature`.
 *
 * A signature only counts when it is a string with non-whitespace content. An
 * empty or blank string is treated the same as a missing one, so the tool call
 * is not mistaken for one that is already signed.
 *
 * @param toolCall - Parsed tool-call object from a chat request message.
 * @returns `true` when a non-blank string signature is present, otherwise `false`.
 */
function hasGeminiThoughtSignature(toolCall: Record<string, unknown>): boolean {
  const extraContent = toolCall.extra_content;
  if (!extraContent || typeof extraContent !== "object") return false;

  const google = (extraContent as Record<string, unknown>).google;
  if (!google || typeof google !== "object") return false;

  const signature = (google as Record<string, unknown>).thought_signature;
  // Reject "" and whitespace-only values: they are strings but carry no signature.
  return typeof signature === "string" && signature.trim().length > 0;
}

/**
 * Adds the bypass thought signature to a serialized chat request body where needed.
 *
 * For every assistant message with a `tool_calls` array, the first tool call
 * whose `type` is `"function"` is inspected. If it has no thought signature
 * yet, `extra_content.google.thought_signature` is set to the bypass value,
 * keeping any other `extra_content` / `google` fields already present. Later
 * tool calls in the same message are left alone.
 *
 * @param bodyText - Request body as a JSON string.
 * @returns The re-serialized body when at least one tool call was patched;
 *   `undefined` when the text is not JSON, has no `messages` array, or nothing
 *   needed patching.
 */
function patchGeminiThoughtSignatureBody(bodyText: string): string | undefined {
  // Narrows any non-null object (arrays included) to an indexable record.
  const asRecord = (value: unknown): Record<string, unknown> | undefined =>
    typeof value === "object" && value !== null ? (value as Record<string, unknown>) : undefined;

  let parsed: unknown;
  try {
    parsed = JSON.parse(bodyText);
  } catch {
    return undefined;
  }

  const messageList = asRecord(parsed)?.messages;
  if (!Array.isArray(messageList)) return undefined;

  let patchedCount = 0;
  messageList.forEach((entry: unknown) => {
    const assistantMessage = asRecord(entry);
    if (!assistantMessage || assistantMessage.role !== "assistant") return;

    const calls = assistantMessage.tool_calls;
    if (!Array.isArray(calls)) return;

    // Locate the first function-type tool call; only that one is considered.
    let target: Record<string, unknown> | undefined;
    for (const candidate of calls) {
      const candidateRecord = asRecord(candidate);
      if (candidateRecord && candidateRecord.type === "function") {
        target = candidateRecord;
        break;
      }
    }
    if (target === undefined || hasGeminiThoughtSignature(target)) return;

    const existingExtra = asRecord(target.extra_content) ?? {};
    const existingGoogle = asRecord(existingExtra.google) ?? {};

    // Merge rather than overwrite so unrelated extra_content fields survive.
    target.extra_content = {
      ...existingExtra,
      google: {
        ...existingGoogle,
        thought_signature: GEMINI_THOUGHT_SIGNATURE_BYPASS,
      },
    };
    patchedCount += 1;
  });

  return patchedCount > 0 ? JSON.stringify(parsed) : undefined;
}

/**
 * Produces a patched copy of a fetch call's arguments when its body needs the
 * Gemini thought-signature bypass.
 *
 * Two body sources are handled:
 * - a string `init.body`, which is patched and returned in a copied `init`;
 * - a request-like `input`, whose body is read from a clone and returned as a
 *   new `Request` with the patched body.
 *
 * In both cases the `content-length` header is dropped from the headers that
 * get copied, since the patched body has a different length.
 *
 * @param input - First argument of the fetch call.
 * @param init - Second argument of the fetch call, if any.
 * @returns Replacement `input`/`init` values, or `undefined` when there is
 *   nothing to patch or the body could not be read.
 */
async function patchGeminiThoughtSignatureRequest(
  input: RequestInfo | URL,
  init?: RequestInit,
): Promise<{ input: RequestInfo | URL; init?: RequestInit } | undefined> {
  if (typeof init?.body === "string") {
    const rewrittenInitBody = patchGeminiThoughtSignatureBody(init.body);
    if (!rewrittenInitBody) return undefined;

    const nextInit: RequestInit = { ...init, body: rewrittenInitBody };
    // Headers are only replaced when the caller supplied some.
    if (init.headers) {
      const initHeaders = new Headers(init.headers);
      initHeaders.delete("content-length");
      nextInit.headers = initHeaders;
    }

    return { input, init: nextInit };
  }

  if (!isRequestLike(input)) return undefined;

  let originalText: string;
  try {
    // Read from a clone so the caller's request body stays unconsumed.
    originalText = await input.clone().text();
  } catch {
    return undefined;
  }

  const rewrittenRequestBody = patchGeminiThoughtSignatureBody(originalText);
  if (!rewrittenRequestBody) return undefined;

  const requestHeaders = new Headers(input.headers);
  requestHeaders.delete("content-length");

  const patchedRequest = new Request(input, {
    body: rewrittenRequestBody,
    headers: requestHeaders,
  });
  return { input: patchedRequest, init };
}

export function getEndpointFetch(endpoint: AIEndpoint, model?: string): typeof globalThis.fetch {
const exactUrl = endpoint.useExactRequestUrl ? endpoint.baseUrl?.trim() : "";
const baseFetch = (_streamingFetch ?? globalThis.fetch).bind(globalThis);
Expand Down Expand Up @@ -94,6 +217,17 @@ export function getEndpointFetch(endpoint: AIEndpoint, model?: string): typeof g
}
}

if (
requestMethod.toUpperCase() === "POST" &&
shouldPatchGeminiThoughtSignatures(endpoint, model, requestUrl)
) {
const patched = await patchGeminiThoughtSignatureRequest(requestInput, requestInit);
if (patched) {
requestInput = patched.input;
requestInit = patched.init;
}
}

logAIEndpointDebug("request", endpoint, {
action: "langchain-chat",
method: requestMethod,
Expand Down Expand Up @@ -271,10 +405,11 @@ export async function createChatModelFromEndpoint(
if (endpoint.useExactRequestUrl && endpoint.baseUrl) {
geminiBaseUrl = endpoint.baseUrl.trim();
} else {
const rawBase = (endpoint.baseUrl || "https://generativelanguage.googleapis.com").replace(/\/+$/, "");
geminiBaseUrl = rawBase.includes("/v1beta/openai")
? rawBase
: `${rawBase}/v1beta/openai`;
const rawBase = (endpoint.baseUrl || "https://generativelanguage.googleapis.com").replace(
/\/+$/,
"",
);
geminiBaseUrl = rawBase.includes("/v1beta/openai") ? rawBase : `${rawBase}/v1beta/openai`;
}

return new ChatOpenAI({
Expand Down