From 92be00781f88f198f6331ec79b2b70ea4f31a225 Mon Sep 17 00:00:00 2001 From: codedogQBY <1369175442@qq.com> Date: Sat, 13 Jun 2026 05:38:32 +0800 Subject: [PATCH] fix(rag): route remote embeddings through platform fetch --- .../settings/VectorModelSettings.tsx | 29 ++-- packages/core/src/rag/index.ts | 8 ++ .../core/src/rag/remote-embedding.test.ts | 125 ++++++++++++++++++ packages/core/src/rag/remote-embedding.ts | 98 ++++++++++++++ packages/core/src/rag/vectorize-trigger.ts | 66 +++------ 5 files changed, 259 insertions(+), 67 deletions(-) create mode 100644 packages/core/src/rag/remote-embedding.test.ts create mode 100644 packages/core/src/rag/remote-embedding.ts diff --git a/packages/app/src/components/settings/VectorModelSettings.tsx b/packages/app/src/components/settings/VectorModelSettings.tsx index d73f04fe..b42726a3 100644 --- a/packages/app/src/components/settings/VectorModelSettings.tsx +++ b/packages/app/src/components/settings/VectorModelSettings.tsx @@ -11,6 +11,7 @@ import { Switch } from "@/components/ui/switch"; import { useVectorModelStore } from "@/stores/vector-model-store"; import { BUILTIN_EMBEDDING_MODELS } from "@readany/core/ai/builtin-embedding-models"; import { clearModelCache, loadEmbeddingPipeline } from "@readany/core/ai/local-embedding-service"; +import { requestRemoteEmbeddingBatch } from "@readany/core/rag"; import type { VectorModelConfig } from "@readany/core/types"; import { Check, Download, Edit2, Loader2, Plus, Trash2, X } from "lucide-react"; import { useCallback, useState } from "react"; @@ -267,24 +268,16 @@ function RemoteModelsSection() { setTestResults((prev) => ({ ...prev, [model.id]: t("settings.vm_testing") })); try { const testUrl = normalizeEmbeddingsUrl(model.url); - const isOllama = testUrl.endsWith("/api/embed"); - const headers: Record = { "Content-Type": "application/json" }; - if (model.apiKey.trim()) headers.Authorization = `Bearer ${model.apiKey}`; - - const requestBody = isOllama - ? { model: model.modelId, input: "test" } - : { input: ["test"], model: model.modelId, encoding_format: "float" }; - - const res = await fetch(testUrl, { - method: "POST", - headers, - body: JSON.stringify(requestBody), - }); - if (!res.ok) throw new Error(`HTTP ${res.status}: ${res.statusText}`); - const json = await res.json(); - const len = isOllama - ? (json?.embeddings?.[0]?.length ?? 0) - : (json?.data?.[0]?.embedding?.length ?? 0); + const result = await requestRemoteEmbeddingBatch( + { + url: testUrl, + modelId: model.modelId, + apiKey: model.apiKey, + }, + ["test"], + ); + if (!result.ok) throw new Error(`HTTP ${result.status}: ${result.errorText}`); + const len = result.embeddings[0]?.length ?? 0; updateVectorModel(model.id, { dimension: len }); setTestResults((prev) => ({ diff --git a/packages/core/src/rag/index.ts b/packages/core/src/rag/index.ts index 722a952e..2aa88d78 100644 --- a/packages/core/src/rag/index.ts +++ b/packages/core/src/rag/index.ts @@ -3,6 +3,14 @@ export type { TextSegment, ChapterData } from "./rag-types"; export { chunkContent, estimateTokens } from "./chunker"; export type { ChunkerConfig } from "./chunker"; +export { requestRemoteEmbeddingBatch, isOllamaEmbeddingUrl } from "./remote-embedding"; +export type { + RemoteEmbeddingBatchOptions, + RemoteEmbeddingBatchResult, + RemoteEmbeddingFetch, + RemoteEmbeddingModel, +} from "./remote-embedding"; + export { EmbeddingService } from "./embedding-service"; export type { EmbeddingConfig } from "./embedding-service"; diff --git a/packages/core/src/rag/remote-embedding.test.ts b/packages/core/src/rag/remote-embedding.test.ts new file mode 100644 index 00000000..bbcfbb1e --- /dev/null +++ b/packages/core/src/rag/remote-embedding.test.ts @@ -0,0 +1,125 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import type { IPlatformService } from "../services/platform"; +import { setPlatformService } from "../services/platform"; +import { requestRemoteEmbeddingBatch } from "./remote-embedding"; + +function createPlatform(fetchImpl: IPlatformService["fetch"]): IPlatformService { + return { + platformType: "desktop", + isMobile: false, + isDesktop: true, + readFile: vi.fn(), + writeFile: vi.fn(), + writeTextFile: vi.fn(), + readTextFile: vi.fn(), + mkdir: vi.fn(), + exists: vi.fn(), + deleteFile: vi.fn(), + getAppDataDir: vi.fn(), + getDataDir: vi.fn(), + joinPath: vi.fn(), + convertFileSrc: vi.fn(), + pickFile: vi.fn(), + loadDatabase: vi.fn(), + fetch: fetchImpl, + createWebSocket: vi.fn(), + getAppVersion: vi.fn(), + kvGetItem: vi.fn(), + kvSetItem: vi.fn(), + kvRemoveItem: vi.fn(), + kvGetAllKeys: vi.fn(), + copyToClipboard: vi.fn(), + shareOrDownloadFile: vi.fn(), + }; +} + +describe("requestRemoteEmbeddingBatch", () => { + afterEach(() => { + setPlatformService(null as unknown as IPlatformService); + vi.restoreAllMocks(); + }); + + it("uses the platform fetch implementation for OpenAI-compatible embeddings", async () => { + const fetchMock = vi.fn(async () => + Response.json({ + data: [ + { index: 1, embedding: [0.3, 0.4] }, + { index: 0, embedding: [0.1, 0.2] }, + ], + }), + ); + setPlatformService(createPlatform(fetchMock)); + + const result = await requestRemoteEmbeddingBatch( + { + url: "http://localhost:11434/v1/embeddings", + modelId: "bge-m3", + apiKey: "ollama", + }, + ["a", "b"], + ); + + expect(fetchMock).toHaveBeenCalledWith( + "http://localhost:11434/v1/embeddings", + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer ollama", + }, + }), + ); + expect(result).toEqual({ + ok: true, + embeddings: [ + [0.1, 0.2], + [0.3, 0.4], + ], + }); + }); + + it("builds Ollama /api/embed requests and omits blank authorization", async () => { + const fetchMock = vi.fn(async () => + Response.json({ + embeddings: [[0.1, 0.2]], + }), + ); + setPlatformService(createPlatform(fetchMock)); + + const result = await requestRemoteEmbeddingBatch( + { + url: "http://localhost:11434/api/embed", + modelId: "bge-m3", + apiKey: " ", + }, + ["hello".repeat(10)], + { maxCharsPerInput: 8 }, + ); + + const [, init] = fetchMock.mock.calls[0]; + expect(init?.headers).toEqual({ "Content-Type": "application/json" }); + expect(JSON.parse(String(init?.body))).toEqual({ + model: "bge-m3", + input: ["hellohel"], + }); + expect(result).toEqual({ ok: true, embeddings: [[0.1, 0.2]] }); + }); + + it("returns status and response text for API errors", async () => { + const fetchMock = vi.fn( + async () => new Response("bad model", { status: 404 }), + ); + setPlatformService(createPlatform(fetchMock)); + + const result = await requestRemoteEmbeddingBatch( + { + url: "http://localhost:11434/v1/embeddings", + modelId: "missing", + apiKey: "", + }, + ["test"], + ); + + expect(result).toEqual({ ok: false, status: 404, errorText: "bad model" }); + }); +}); diff --git a/packages/core/src/rag/remote-embedding.ts b/packages/core/src/rag/remote-embedding.ts new file mode 100644 index 00000000..fb34029b --- /dev/null +++ b/packages/core/src/rag/remote-embedding.ts @@ -0,0 +1,98 @@ +import { getPlatformService } from "../services/platform"; + +export interface RemoteEmbeddingModel { + url: string; + modelId: string; + apiKey: string; +} + +export type RemoteEmbeddingBatchResult = + | { ok: true; embeddings: number[][] } + | { ok: false; status: number; errorText: string }; + +export type RemoteEmbeddingFetch = (url: string, init: RequestInit) => Promise; + +export interface RemoteEmbeddingBatchOptions { + fetchImpl?: RemoteEmbeddingFetch; + maxCharsPerInput?: number; +} + +interface OpenAIEmbeddingItem { + embedding: number[]; + index: number; +} + +export async function requestRemoteEmbeddingBatch( + model: RemoteEmbeddingModel, + inputTexts: string[], + options: RemoteEmbeddingBatchOptions = {}, +): Promise { + const isOllama = isOllamaEmbeddingUrl(model.url); + const maxCharsPerInput = options.maxCharsPerInput; + const safeTexts = + typeof maxCharsPerInput === "number" && maxCharsPerInput > 0 + ? inputTexts.map((text) => + text.length > maxCharsPerInput ? text.slice(0, maxCharsPerInput) : text, + ) + : inputTexts; + const requestBody = isOllama + ? { model: model.modelId, input: safeTexts } + : { + input: safeTexts, + model: model.modelId, + encoding_format: "float", + }; + + const headers: Record = { + "Content-Type": "application/json", + }; + if (model.apiKey.trim()) { + headers.Authorization = `Bearer ${model.apiKey}`; + } + + const fetchImpl = options.fetchImpl ?? getRemoteEmbeddingFetch(); + const response = await fetchImpl(model.url, { + method: "POST", + headers, + body: JSON.stringify(requestBody), + }); + + if (!response.ok) { + const errorText = await response.text().catch(() => ""); + return { ok: false, status: response.status, errorText }; + } + + const json = await response.json(); + return { + ok: true, + embeddings: parseRemoteEmbeddingResponse(json, isOllama), + }; +} + +export function isOllamaEmbeddingUrl(url: string): boolean { + return url.replace(/\/$/, "").endsWith("/api/embed"); +} + +function getRemoteEmbeddingFetch(): RemoteEmbeddingFetch { + try { + const platform = getPlatformService(); + return (url, init) => platform.fetch(url, init); + } catch { + return (url, init) => globalThis.fetch(url, init); + } +} + +function parseRemoteEmbeddingResponse(json: unknown, isOllama: boolean): number[][] { + if (isOllama) { + const embeddings = (json as { embeddings?: number[][] })?.embeddings; + return Array.isArray(embeddings) ? embeddings : []; + } + + const data = (json as { data?: OpenAIEmbeddingItem[] })?.data; + if (!Array.isArray(data)) return []; + + return data + .slice() + .sort((a, b) => a.index - b.index) + .map((item) => item.embedding); +} diff --git a/packages/core/src/rag/vectorize-trigger.ts b/packages/core/src/rag/vectorize-trigger.ts index 9ca7b0bc..767d7267 100644 --- a/packages/core/src/rag/vectorize-trigger.ts +++ b/packages/core/src/rag/vectorize-trigger.ts @@ -14,6 +14,7 @@ import type { VectorizeProgress } from "../types"; import { eventBus } from "../utils/event-bus"; import { chunkContent } from "./chunker"; import type { ChapterData } from "./rag-types"; +import { requestRemoteEmbeddingBatch } from "./remote-embedding"; import { invalidateChunkCache } from "./search"; import { getVectorDB, hasVectorDB } from "./vector-db"; import type { VectorRecord } from "./vector-db"; @@ -166,13 +167,17 @@ export async function triggerVectorizeBook( if (vectorDB && (await vectorDB.isReady())) { await vectorDB.deleteByBookId(bookId); - const vectorRecords: VectorRecord[] = allChunks - .filter((c) => c.embedding && c.embedding.length > 0) - .map((c) => ({ - id: c.id, - bookId: c.bookId, - embedding: c.embedding!, - })); + const vectorRecords: VectorRecord[] = allChunks.flatMap((c) => + c.embedding && c.embedding.length > 0 + ? [ + { + id: c.id, + bookId: c.bookId, + embedding: c.embedding, + }, + ] + : [], + ); if (vectorRecords.length > 0) { // Detect actual embedding dimension and reinit vector DB if needed @@ -283,14 +288,6 @@ async function generateRemoteEmbeddings( ); } - const isOllama = selectedModel.url.endsWith("/api/embed"); - const headers: Record = { - "Content-Type": "application/json", - }; - if (selectedModel.apiKey.trim()) { - headers.Authorization = `Bearer ${selectedModel.apiKey}`; - } - // Conservative defaults: many Chinese embedding APIs (Baidu/MiniMax/etc.) // cap single-input tokens at 384–1024 and batch size at 16. Send 8 per // request and keep single chunks ≤ 1800 chars (~ ≤ 1800 tokens for CJK) @@ -300,41 +297,12 @@ async function generateRemoteEmbeddings( const callEmbeddingApi = async ( inputTexts: string[], - ): Promise<{ ok: true; embeddings: number[][] } | { ok: false; status: number; errorText: string }> => { - const safeTexts = inputTexts.map((t) => - t.length > MAX_CHARS_PER_CHUNK ? t.slice(0, MAX_CHARS_PER_CHUNK) : t, - ); - const requestBody = isOllama - ? { model: selectedModel.modelId, input: safeTexts } - : { - input: safeTexts, - model: selectedModel.modelId, - encoding_format: "float", - }; - - const res = await fetch(selectedModel.url, { - method: "POST", - headers, - body: JSON.stringify(requestBody), + ): Promise< + { ok: true; embeddings: number[][] } | { ok: false; status: number; errorText: string } + > => { + return requestRemoteEmbeddingBatch(selectedModel, inputTexts, { + maxCharsPerInput: MAX_CHARS_PER_CHUNK, }); - - if (!res.ok) { - const errorText = await res.text().catch(() => ""); - return { ok: false, status: res.status, errorText }; - } - - const json = await res.json(); - const embeddings: number[][] = isOllama - ? (json?.embeddings ?? []) - : ( - (json?.data ?? []) as Array<{ - embedding: number[]; - index: number; - }> - ) - .sort((a: any, b: any) => a.index - b.index) - .map((d: any) => d.embedding); - return { ok: true, embeddings }; }; for (let i = 0; i < chunks.length; i += batchSize) {