Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 11 additions & 18 deletions packages/app/src/components/settings/VectorModelSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { Switch } from "@/components/ui/switch";
import { useVectorModelStore } from "@/stores/vector-model-store";
import { BUILTIN_EMBEDDING_MODELS } from "@readany/core/ai/builtin-embedding-models";
import { clearModelCache, loadEmbeddingPipeline } from "@readany/core/ai/local-embedding-service";
import { requestRemoteEmbeddingBatch } from "@readany/core/rag";
import type { VectorModelConfig } from "@readany/core/types";
import { Check, Download, Edit2, Loader2, Plus, Trash2, X } from "lucide-react";
import { useCallback, useState } from "react";
Expand Down Expand Up @@ -267,24 +268,16 @@ function RemoteModelsSection() {
setTestResults((prev) => ({ ...prev, [model.id]: t("settings.vm_testing") }));
try {
const testUrl = normalizeEmbeddingsUrl(model.url);
const isOllama = testUrl.endsWith("/api/embed");
const headers: Record<string, string> = { "Content-Type": "application/json" };
if (model.apiKey.trim()) headers.Authorization = `Bearer ${model.apiKey}`;

const requestBody = isOllama
? { model: model.modelId, input: "test" }
: { input: ["test"], model: model.modelId, encoding_format: "float" };

const res = await fetch(testUrl, {
method: "POST",
headers,
body: JSON.stringify(requestBody),
});
if (!res.ok) throw new Error(`HTTP ${res.status}: ${res.statusText}`);
const json = await res.json();
const len = isOllama
? (json?.embeddings?.[0]?.length ?? 0)
: (json?.data?.[0]?.embedding?.length ?? 0);
const result = await requestRemoteEmbeddingBatch(
{
url: testUrl,
modelId: model.modelId,
apiKey: model.apiKey,
},
["test"],
);
if (!result.ok) throw new Error(`HTTP ${result.status}: ${result.errorText}`);
const len = result.embeddings[0]?.length ?? 0;

updateVectorModel(model.id, { dimension: len });
setTestResults((prev) => ({
Expand Down
8 changes: 8 additions & 0 deletions packages/core/src/rag/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ export type { TextSegment, ChapterData } from "./rag-types";
export { chunkContent, estimateTokens } from "./chunker";
export type { ChunkerConfig } from "./chunker";

export { requestRemoteEmbeddingBatch, isOllamaEmbeddingUrl } from "./remote-embedding";
export type {
RemoteEmbeddingBatchOptions,
RemoteEmbeddingBatchResult,
RemoteEmbeddingFetch,
RemoteEmbeddingModel,
} from "./remote-embedding";

export { EmbeddingService } from "./embedding-service";
export type { EmbeddingConfig } from "./embedding-service";

Expand Down
125 changes: 125 additions & 0 deletions packages/core/src/rag/remote-embedding.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import type { IPlatformService } from "../services/platform";
import { setPlatformService } from "../services/platform";
import { requestRemoteEmbeddingBatch } from "./remote-embedding";

function createPlatform(fetchImpl: IPlatformService["fetch"]): IPlatformService {
return {
platformType: "desktop",
isMobile: false,
isDesktop: true,
readFile: vi.fn(),
writeFile: vi.fn(),
writeTextFile: vi.fn(),
readTextFile: vi.fn(),
mkdir: vi.fn(),
exists: vi.fn(),
deleteFile: vi.fn(),
getAppDataDir: vi.fn(),
getDataDir: vi.fn(),
joinPath: vi.fn(),
convertFileSrc: vi.fn(),
pickFile: vi.fn(),
loadDatabase: vi.fn(),
fetch: fetchImpl,
createWebSocket: vi.fn(),
getAppVersion: vi.fn(),
kvGetItem: vi.fn(),
kvSetItem: vi.fn(),
kvRemoveItem: vi.fn(),
kvGetAllKeys: vi.fn(),
copyToClipboard: vi.fn(),
shareOrDownloadFile: vi.fn(),
};
}

describe("requestRemoteEmbeddingBatch", () => {
afterEach(() => {
setPlatformService(null as unknown as IPlatformService);
vi.restoreAllMocks();
});

it("uses the platform fetch implementation for OpenAI-compatible embeddings", async () => {
const fetchMock = vi.fn<IPlatformService["fetch"]>(async () =>
Response.json({
data: [
{ index: 1, embedding: [0.3, 0.4] },
{ index: 0, embedding: [0.1, 0.2] },
],
}),
);
setPlatformService(createPlatform(fetchMock));

const result = await requestRemoteEmbeddingBatch(
{
url: "http://localhost:11434/v1/embeddings",
modelId: "bge-m3",
apiKey: "ollama",
},
["a", "b"],
);

expect(fetchMock).toHaveBeenCalledWith(
"http://localhost:11434/v1/embeddings",
expect.objectContaining({
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: "Bearer ollama",
},
}),
);
expect(result).toEqual({
ok: true,
embeddings: [
[0.1, 0.2],
[0.3, 0.4],
],
});
});

it("builds Ollama /api/embed requests and omits blank authorization", async () => {
const fetchMock = vi.fn<IPlatformService["fetch"]>(async () =>
Response.json({
embeddings: [[0.1, 0.2]],
}),
);
setPlatformService(createPlatform(fetchMock));

const result = await requestRemoteEmbeddingBatch(
{
url: "http://localhost:11434/api/embed",
modelId: "bge-m3",
apiKey: " ",
},
["hello".repeat(10)],
{ maxCharsPerInput: 8 },
);

const [, init] = fetchMock.mock.calls[0];
expect(init?.headers).toEqual({ "Content-Type": "application/json" });
expect(JSON.parse(String(init?.body))).toEqual({
model: "bge-m3",
input: ["hellohel"],
});
expect(result).toEqual({ ok: true, embeddings: [[0.1, 0.2]] });
});

it("returns status and response text for API errors", async () => {
const fetchMock = vi.fn<IPlatformService["fetch"]>(
async () => new Response("bad model", { status: 404 }),
);
setPlatformService(createPlatform(fetchMock));

const result = await requestRemoteEmbeddingBatch(
{
url: "http://localhost:11434/v1/embeddings",
modelId: "missing",
apiKey: "",
},
["test"],
);

expect(result).toEqual({ ok: false, status: 404, errorText: "bad model" });
});
});
98 changes: 98 additions & 0 deletions packages/core/src/rag/remote-embedding.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import { getPlatformService } from "../services/platform";

export interface RemoteEmbeddingModel {
url: string;
modelId: string;
apiKey: string;
}

export type RemoteEmbeddingBatchResult =
| { ok: true; embeddings: number[][] }
| { ok: false; status: number; errorText: string };

export type RemoteEmbeddingFetch = (url: string, init: RequestInit) => Promise<Response>;

export interface RemoteEmbeddingBatchOptions {
fetchImpl?: RemoteEmbeddingFetch;
maxCharsPerInput?: number;
}

interface OpenAIEmbeddingItem {
embedding: number[];
index: number;
}

export async function requestRemoteEmbeddingBatch(
model: RemoteEmbeddingModel,
inputTexts: string[],
options: RemoteEmbeddingBatchOptions = {},
): Promise<RemoteEmbeddingBatchResult> {
const isOllama = isOllamaEmbeddingUrl(model.url);
const maxCharsPerInput = options.maxCharsPerInput;
const safeTexts =
typeof maxCharsPerInput === "number" && maxCharsPerInput > 0
? inputTexts.map((text) =>
text.length > maxCharsPerInput ? text.slice(0, maxCharsPerInput) : text,
)
: inputTexts;
const requestBody = isOllama
? { model: model.modelId, input: safeTexts }
: {
input: safeTexts,
model: model.modelId,
encoding_format: "float",
};

const headers: Record<string, string> = {
"Content-Type": "application/json",
};
if (model.apiKey.trim()) {
headers.Authorization = `Bearer ${model.apiKey}`;
}

const fetchImpl = options.fetchImpl ?? getRemoteEmbeddingFetch();
const response = await fetchImpl(model.url, {
method: "POST",
headers,
body: JSON.stringify(requestBody),
});

if (!response.ok) {
const errorText = await response.text().catch(() => "");
return { ok: false, status: response.status, errorText };
}

const json = await response.json();
return {
ok: true,
embeddings: parseRemoteEmbeddingResponse(json, isOllama),
};
}

export function isOllamaEmbeddingUrl(url: string): boolean {
return url.replace(/\/$/, "").endsWith("/api/embed");
}

function getRemoteEmbeddingFetch(): RemoteEmbeddingFetch {
try {
const platform = getPlatformService();
return (url, init) => platform.fetch(url, init);
} catch {
return (url, init) => globalThis.fetch(url, init);
}
}

function parseRemoteEmbeddingResponse(json: unknown, isOllama: boolean): number[][] {
if (isOllama) {
const embeddings = (json as { embeddings?: number[][] })?.embeddings;
return Array.isArray(embeddings) ? embeddings : [];
}

const data = (json as { data?: OpenAIEmbeddingItem[] })?.data;
if (!Array.isArray(data)) return [];

return data
.slice()
.sort((a, b) => a.index - b.index)
.map((item) => item.embedding);
}
66 changes: 17 additions & 49 deletions packages/core/src/rag/vectorize-trigger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type { VectorizeProgress } from "../types";
import { eventBus } from "../utils/event-bus";
import { chunkContent } from "./chunker";
import type { ChapterData } from "./rag-types";
import { requestRemoteEmbeddingBatch } from "./remote-embedding";
import { invalidateChunkCache } from "./search";
import { getVectorDB, hasVectorDB } from "./vector-db";
import type { VectorRecord } from "./vector-db";
Expand Down Expand Up @@ -166,13 +167,17 @@ export async function triggerVectorizeBook(
if (vectorDB && (await vectorDB.isReady())) {
await vectorDB.deleteByBookId(bookId);

const vectorRecords: VectorRecord[] = allChunks
.filter((c) => c.embedding && c.embedding.length > 0)
.map((c) => ({
id: c.id,
bookId: c.bookId,
embedding: c.embedding!,
}));
const vectorRecords: VectorRecord[] = allChunks.flatMap((c) =>
c.embedding && c.embedding.length > 0
? [
{
id: c.id,
bookId: c.bookId,
embedding: c.embedding,
},
]
: [],
);

if (vectorRecords.length > 0) {
// Detect actual embedding dimension and reinit vector DB if needed
Expand Down Expand Up @@ -283,14 +288,6 @@ async function generateRemoteEmbeddings(
);
}

const isOllama = selectedModel.url.endsWith("/api/embed");
const headers: Record<string, string> = {
"Content-Type": "application/json",
};
if (selectedModel.apiKey.trim()) {
headers.Authorization = `Bearer ${selectedModel.apiKey}`;
}

// Conservative defaults: many Chinese embedding APIs (Baidu/MiniMax/etc.)
// cap single-input tokens at 384–1024 and batch size at 16. Send 8 per
// request and keep single chunks ≤ 1800 chars (~ ≤ 1800 tokens for CJK)
Expand All @@ -300,41 +297,12 @@ async function generateRemoteEmbeddings(

const callEmbeddingApi = async (
inputTexts: string[],
): Promise<{ ok: true; embeddings: number[][] } | { ok: false; status: number; errorText: string }> => {
const safeTexts = inputTexts.map((t) =>
t.length > MAX_CHARS_PER_CHUNK ? t.slice(0, MAX_CHARS_PER_CHUNK) : t,
);
const requestBody = isOllama
? { model: selectedModel.modelId, input: safeTexts }
: {
input: safeTexts,
model: selectedModel.modelId,
encoding_format: "float",
};

const res = await fetch(selectedModel.url, {
method: "POST",
headers,
body: JSON.stringify(requestBody),
): Promise<
{ ok: true; embeddings: number[][] } | { ok: false; status: number; errorText: string }
> => {
return requestRemoteEmbeddingBatch(selectedModel, inputTexts, {
maxCharsPerInput: MAX_CHARS_PER_CHUNK,
});

if (!res.ok) {
const errorText = await res.text().catch(() => "");
return { ok: false, status: res.status, errorText };
}

const json = await res.json();
const embeddings: number[][] = isOllama
? (json?.embeddings ?? [])
: (
(json?.data ?? []) as Array<{
embedding: number[];
index: number;
}>
)
.sort((a: any, b: any) => a.index - b.index)
.map((d: any) => d.embedding);
return { ok: true, embeddings };
};

for (let i = 0; i < chunks.length; i += batchSize) {
Expand Down