diff --git a/.github/copilot-setup-steps.yml b/.github/copilot-setup-steps.yml new file mode 100644 index 0000000..245bc2b --- /dev/null +++ b/.github/copilot-setup-steps.yml @@ -0,0 +1,3 @@ +steps: + - name: Install dependencies + run: npm install diff --git a/README.md b/README.md index c6fa75b..e3562c7 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # embedeer A Node.js tool for generating text embeddings using models from [Hugging Face](https://huggingface.co/models). -Supports **batched** input, **parallel** execution, isolated **child-process** workers (default) or **in-process threads**, quantization, and Hugging Face auth. +Supports **batched** input, **parallel** execution, isolated **child-process** workers (default) or **in-process threads**, quantization, optional GPU acceleration, and Hugging Face auth. --- @@ -12,6 +12,7 @@ Supports **batched** input, **parallel** execution, isolated **child-process** w - **In-process threads** — opt-in via `mode: 'thread'` for lower overhead - **Sequential** execution when `concurrency: 1` - Configurable batch size and concurrency +- **GPU acceleration** — optional CUDA (Linux x64) and DirectML (Windows x64), no extra packages needed - Hugging Face API token support (`--token` / `HF_TOKEN` env var) - Quantization via `dtype` (`fp32` · `fp16` · `q8` · `q4` · `q4f16` · `auto`) - Rich CLI: pull model, embed from file, dump output as JSON / TXT / SQL @@ -21,17 +22,27 @@ Supports **batched** input, **parallel** execution, isolated **child-process** w ## Installation ```bash -npm install +npm install @jsilvanus/embedeer +``` + +GPU acceleration (CUDA on Linux x64, DirectML on Windows x64) is built into `onnxruntime-node` +which ships as a transitive dependency. No additional packages are required. 
+ +**For CUDA on Linux x64** you also need the CUDA 12 system libraries: + +```bash +# Ubuntu / Debian +sudo apt install cuda-toolkit-12-6 libcudnn9-cuda-12 ``` --- ## Programmatic API -### Embed texts +### Embed texts (CPU — default) ```js -import { Embedder } from 'embedeer'; +import { Embedder } from '@jsilvanus/embedeer'; const embedder = await Embedder.create('Xenova/all-MiniLM-L6-v2', { batchSize: 32, // texts per worker task (default: 32) @@ -50,12 +61,33 @@ const vectors = await embedder.embed(['Hello world', 'Foo bar baz']); await embedder.destroy(); // shut down worker processes ``` +### Embed texts with GPU + +```js +import { Embedder } from '@jsilvanus/embedeer'; + +// Auto-detect GPU (falls back to CPU if no GPU provider is available) +const embedder = await Embedder.create('Xenova/all-MiniLM-L6-v2', { + device: 'auto', +}); + +// Require GPU (throws if no provider is available) +const embedder = await Embedder.create('Xenova/all-MiniLM-L6-v2', { + device: 'gpu', +}); + +// Explicitly select an execution provider +const embedder = await Embedder.create('Xenova/all-MiniLM-L6-v2', { + provider: 'cuda', // 'cuda' | 'dml' +}); +``` + ### Pull (pre-cache) a model Like `ollama pull` — downloads the model once so workers start instantly: ```js -import { loadModel } from 'embedeer'; +import { loadModel } from '@jsilvanus/embedeer'; const { modelName, cacheDir } = await loadModel('Xenova/all-MiniLM-L6-v2', { token: 'hf_...', // optional @@ -63,72 +95,302 @@ const { modelName, cacheDir } = await loadModel('Xenova/all-MiniLM-L6-v2', { }); ``` -### Sequential execution +--- + +## CLI -```js -const embedder = await Embedder.create('Xenova/all-MiniLM-L6-v2', { concurrency: 1 }); ``` +npx @jsilvanus/embedeer [options] + +Model management (pull / cache model): + npx @jsilvanus/embedeer --model -### In-process threads (same process, lower overhead) +Embed texts (batch): + npx @jsilvanus/embedeer --model --data "text1" "text2" ... 
+ npx @jsilvanus/embedeer --model --data '["text1","text2"]' + npx @jsilvanus/embedeer --model --file texts.txt + echo '["t1","t2"]' | npx @jsilvanus/embedeer --model + printf 'a\0b\0c' | npx @jsilvanus/embedeer --model --delimiter '\0' -```js -const embedder = await Embedder.create('Xenova/all-MiniLM-L6-v2', { mode: 'thread' }); +Interactive / streaming line-reader: + npx @jsilvanus/embedeer --model --interactive --dump out.jsonl + cat big.txt | npx @jsilvanus/embedeer --model -i --output csv --dump out.csv + +Options: + -m, --model Hugging Face model (default: Xenova/all-MiniLM-L6-v2) + -d, --data Text(s) or JSON array to embed + --file Input file: JSON array or delimited texts + -D, --delimiter Record separator for stdin/file (default: \n) + Escape sequences supported: \0 \n \t \r + -i, --interactive Interactive line-reader (see below) + --dump Write output to file instead of stdout + --output Output: json|jsonl|csv|txt|sql (default: json) + --with-text Include source text alongside each embedding + -b, --batch-size Texts per worker batch (default: 32) + -c, --concurrency Parallel workers (default: 2) + --mode process|thread Worker mode (default: process) + -p, --pooling mean|cls|none (default: mean) + --no-normalize Disable L2 normalisation + --dtype Quantization: fp32|fp16|q8|q4|q4f16|auto + --token Hugging Face API token (or set HF_TOKEN env) + --cache-dir Model cache directory (default: ~/.embedeer/models) + --device Compute device: auto|cpu|gpu (default: cpu) + --provider Execution provider override: cpu|cuda|dml + -h, --help Show this help ``` --- -## CLI +## Input Sources + +Texts can be provided in any of these ways (checked in order): + +| Source | How | +|--------|-----| +| Inline args | `--data "text1" "text2" "text3"` | +| Inline JSON | `--data '["text1","text2"]'` | +| File | `--file texts.txt` (JSON array or one record per line) | +| Stdin | Pipe or redirect — auto-detected; TTY is skipped | +| Interactive | `--interactive` / `-i` — line-reader, 
embeds as you type | + +**Stdin auto-detection:** when `stdin` is not a TTY (i.e. data is piped or redirected), embedeer reads it before deciding what to do. JSON arrays are accepted directly; otherwise records are split on the delimiter. + +--- + +## Interactive Line-Reader Mode (`-i` / `--interactive`) + +The interactive mode opens a line-by-line reader that starts embedding as records arrive — ideal for pasting large datasets into a terminal or streaming data from another process. +```bash +# Open an interactive session (paste lines, Ctrl+D when done) +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --interactive --dump embeddings.jsonl + +# Stream a large file through interactive mode with CSV output +cat big.txt | npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 \ + --interactive --output csv --dump embeddings.csv + +# Interactive with GPU, custom batch size, txt output +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 \ + --interactive --device auto --batch-size 16 --output txt --dump vecs.txt ``` -npx embedeer [options] -Model management (pull / cache model): - npx embedeer --model +**How it works:** -Embed texts: - npx embedeer --model --data "text1" "text2" ... 
- npx embedeer --model --data '["text1","text2"]' - npx embedeer --model --file texts.txt - echo '["t1","t2"]' | npx embedeer --model +| Event | What happens | +|-------|-------------| +| Type a line, press Enter | Record is buffered | +| Buffer reaches `--batch-size` | Auto-flush: embed + append to output | +| Type an empty line | Manual flush: embed whatever is buffered | +| Ctrl+D (EOF) | Flush remaining records and exit | +| Ctrl+C | Flush remaining records and exit | -Options: - -m, --model Hugging Face model (default: Xenova/all-MiniLM-L6-v2) - -d, --data Text(s) or JSON array to embed - --file Input file: JSON array or one text per line - --dump Write output to file instead of stdout - --output json|txt|sql Output format (default: json) - -b, --batch-size Texts per worker batch (default: 32) - -c, --concurrency Parallel workers (default: 2) - --mode process|thread Worker mode (default: process) - -p, --pooling mean|cls|none (default: mean) - --no-normalize Disable L2 normalisation - --dtype Quantization: fp32|fp16|q8|q4|q4f16|auto - --token Hugging Face API token (or set HF_TOKEN env) - --cache-dir Model cache directory (default: ~/.embedeer/models) - -h, --help Show this help -``` - -### Examples +**Behaviour notes:** + +- Progress messages (`Batch N: M record(s) → file`) always go to **stderr** — they never pollute piped output. +- When stdin is a TTY, a `> ` prompt is shown on stderr. +- Output defaults to **stdout** if `--dump` is omitted; a tip is printed when running in TTY mode. +- `--output json` and `--output sql` are automatically promoted to `jsonl` since they produce complete documents that cannot be appended to incrementally. +- `--output csv` writes the dimension header (`text,dim_0,dim_1,...`) on the first batch only; subsequent batches append data rows. +- Each interactive session **clears** the `--dump` file on start so you always get a fresh output file. 
+ +### Configurable delimiter (`-D` / `--delimiter`) + +By default, records in stdin and files are split on newline (`\n`). Use `--delimiter` to change it: + +```bash +# Newline-delimited (default) +printf 'Hello\nWorld\n' | npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 + +# Null-byte delimited — safe with filenames/texts that contain newlines +printf 'Hello\0World\0' | npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --delimiter '\0' + +# Tab-delimited +printf 'Hello\tWorld' | npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --delimiter '\t' + +# Custom multi-character delimiter +printf 'Hello|||World|||Foo' | npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --delimiter '|||' + +# File with null-byte delimiter +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --file records.bin --delimiter '\0' + +# Integrate with find -print0 (handles filenames with spaces / newlines) +find ./docs -name '*.txt' -print0 | \ + xargs -0 cat | \ + npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --delimiter '\0' +``` + +Supported escape sequences in `--delimiter`: + +| Sequence | Character | +|----------|-----------| +| `\0` | Null byte (U+0000) | +| `\n` | Newline (U+000A) | +| `\t` | Tab (U+0009) | +| `\r` | Carriage return (U+000D) | + +--- + +## Output Formats + +| Format | Description | +|--------|-------------| +| `json` (default) | JSON array of float arrays: `[[0.1,0.2,...],[...]]` | +| `json --with-text` | JSON array of objects: `[{"text":"...","embedding":[...]}]` | +| `jsonl` | Newline-delimited JSON, one object per line: `{"text":"...","embedding":[...]}` | +| `csv` | CSV with header: `text,dim_0,dim_1,...,dim_N` | +| `txt` | Space-separated floats, one vector per line | +| `txt --with-text` | Tab-separated, one record per line: the source text, a tab (`\t`), then the space-separated floats | +| `sql` | `INSERT INTO embeddings (text, vector) VALUES ...;` | + +Use `--dump` followed by a file path to write the output to a file instead of stdout. Progress messages always go to stderr so they never interfere with piped output. 
+ +### Piping examples + +```bash +MODEL=Xenova/all-MiniLM-L6-v2 + +# --- json (default) --- +# Embed and pretty-print with jq +echo '["Hello","World"]' | npx @jsilvanus/embedeer --model $MODEL | jq '.[0] | length' + +# --- jsonl --- +# One object per line — pipe to jq, grep, awk, etc. +npx @jsilvanus/embedeer --model $MODEL --data "foo" "bar" --output jsonl + +# Filter by similarity: extract embedding for downstream processing +npx @jsilvanus/embedeer --model $MODEL --data "query text" --output jsonl \ + | jq -c '.embedding' + +# Stream a large file and store as JSONL +npx @jsilvanus/embedeer --model $MODEL --file big.txt --output jsonl --dump out.jsonl + +# --- json --with-text --- +# Keep the source text next to each vector (useful for building a search index) +npx @jsilvanus/embedeer --model $MODEL --output json --with-text \ + --data "cat" "dog" "fish" \ + | jq '.[] | {text, dims: (.embedding | length)}' + +# --- csv --- +# Embed then open in Python/pandas +npx @jsilvanus/embedeer --model $MODEL --file texts.txt --output csv --dump vectors.csv +python3 -c "import pandas as pd; df = pd.read_csv('vectors.csv'); print(df.shape)" + +# --- txt --- +# Raw floats — useful for awk/paste/numpy text loading +npx @jsilvanus/embedeer --model $MODEL --data "Hello" "World" --output txt \ + | awk '{print NF, "dimensions"}' + +# txt --with-text: original text + tab + floats, easy to parse +npx @jsilvanus/embedeer --model $MODEL --file texts.txt --output txt --with-text \ + | while IFS=$'\t' read -r text vec; do echo "TEXT: $text"; done + +# --- sql --- +# Generate INSERT statements for a vector DB or SQLite +npx @jsilvanus/embedeer --model $MODEL --file texts.txt --output sql --dump inserts.sql +sqlite3 mydb.sqlite < inserts.sql + +# --- Chaining with other tools --- +# Embed stdin from another command +cat docs/*.txt | npx @jsilvanus/embedeer --model $MODEL --output jsonl > embeddings.jsonl + +# Null-byte input from find (handles any filename or text with newlines) +find 
./corpus -name '*.txt' -print0 \ + | xargs -0 cat \ + | npx @jsilvanus/embedeer --model $MODEL --delimiter '\0' --output jsonl +``` + +--- + +### CLI Examples ```bash # Pull a model (like ollama pull) -npx embedeer --model Xenova/all-MiniLM-L6-v2 +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 + +# Embed a few strings, output JSON (CPU) +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --data "Hello" "World" + +# Auto-detect GPU, fall back to CPU if unavailable +# (uses CUDA on Linux, DirectML on Windows, CPU everywhere else) +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --device auto --data "Hello" + +# Require GPU (throws with install instructions if no provider found) +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --device gpu --data "Hello GPU" + +# Explicit CUDA (Linux x64 — requires CUDA 12 system libraries) +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --provider cuda --data "Hello CUDA" -# Embed a few strings, output JSON -npx embedeer --model Xenova/all-MiniLM-L6-v2 --data "Hello" "World" +# Explicit DirectML (Windows x64) +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --provider dml --data "Hello DML" # Embed from a file, dump SQL to disk -npx embedeer --model Xenova/all-MiniLM-L6-v2 \ +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 \ --file texts.txt --output sql --dump out.sql # Use quantized model, in-process threads, private model with token -npx embedeer --model my-org/private-model \ +npx @jsilvanus/embedeer --model my-org/private-model \ --token hf_xxx --dtype q8 --mode thread \ --data "embed me" ``` --- +### Using GPU + +No additional packages are needed — `onnxruntime-node` (installed with `@jsilvanus/embedeer`) already +bundles the CUDA provider on Linux x64 and DirectML on Windows x64. 
+ +**Linux x64 — NVIDIA CUDA:** + +```bash +# One-time: install CUDA 12 system libraries (Ubuntu/Debian) +sudo apt install cuda-toolkit-12-6 libcudnn9-cuda-12 + +# Auto-detect: uses CUDA here, CPU fallback on any other machine +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --device auto --data "Hello" + +# Hard-require CUDA (throws with diagnostic error if unavailable): +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --device gpu --data "Hello GPU" + +# Explicit CUDA provider: +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --provider cuda --data "Hello CUDA" +``` + +**Windows x64 — DirectML (any GPU: NVIDIA / AMD / Intel):** + +```bash +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --device auto --data "Hello" +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --device gpu --data "Hello GPU" +npx @jsilvanus/embedeer --model Xenova/all-MiniLM-L6-v2 --provider dml --data "Hello DML" +``` + +--- + +## GPU Acceleration + +GPU support is built into `onnxruntime-node` (a dependency of `@huggingface/transformers`): + +| Platform | Provider | Requirement | +|----------------|-----------|--------------------------------------------------------| +| Linux x64 | CUDA | NVIDIA GPU + driver ≥ 525, CUDA 12 toolkit, cuDNN 9 | +| Windows x64 | DirectML | Any DirectX 12 GPU (most GPUs since 2016), Windows 10+ | + +### Provider selection logic + +| `device` | `provider` | Behavior | +|----------|-----------|----------| +| `cpu` (default) | — | Always CPU | +| `auto` | — | Try GPU providers for the platform in order; silent CPU fallback | +| `gpu` | — | Try GPU providers; **throw** if none available | +| any | `cuda` | Load CUDA provider; **throw** if not available or not supported | +| any | `dml` | Load DirectML provider; **throw** if not available or not supported | +| any | `cpu` | Always CPU | + +On Linux x64: GPU order is `cuda`. +On Windows x64: GPU order is `cuda → dml`. 
+ +--- + ## How it works ``` @@ -138,17 +400,18 @@ embed(texts) │ └─ Promise.all(batches) ──► WorkerPool │ - ├─ [process mode] ChildProcessWorker 0 → batch A - ├─ [process mode] ChildProcessWorker 1 → batch B - │ (OS-level isolation; crash → reject only that task) + ├─ [process mode] ChildProcessWorker 0 + │ resolveProvider(device, provider) + │ → pipeline('feature-extraction', model, { device: 'cuda' }) + │ → embed batch A │ - ├─ [thread mode] ThreadWorker 0 → batch A - └─ [thread mode] ThreadWorker 1 → batch B + └─ [process mode] ChildProcessWorker 1 + resolveProvider(device, provider) + → pipeline(...) → embed batch B ``` -Workers load the model **once** at startup and reuse it for all batches, avoiding -repeated download overhead. Models are cached in `~/.embedeer/models` so -subsequent runs start instantly. +Workers load the model **once** at startup and reuse it for all batches. +Provider activation happens per-worker before the pipeline is created. --- @@ -158,5 +421,4 @@ subsequent runs start instantly. npm test ``` -Tests use Node's built-in `node:test` runner. Worker behaviour is tested with -lightweight fake/echo workers — no real model download required. +Tests use Node's built-in `node:test` runner. No real model download required. 
diff --git a/package-lock.json b/package-lock.json index 12e891f..3efd3e2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,11 +1,11 @@ { - "name": "embedeer", + "name": "@jsilvanus/embedeer", "version": "1.0.0", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "embedeer", + "name": "@jsilvanus/embedeer", "version": "1.0.0", "license": "ISC", "dependencies": { @@ -13,6 +13,9 @@ }, "bin": { "embedeer": "src/cli.js" + }, + "engines": { + "node": ">=18" } }, "node_modules/@emnapi/runtime": { @@ -1039,6 +1042,51 @@ "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.18.2.tgz", "integrity": "sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==", "license": "MIT" + }, + "packages/embedeer": { + "name": "@jsilvanus/embedeer", + "version": "1.0.0", + "extraneous": true, + "license": "ISC", + "dependencies": { + "@huggingface/transformers": "^4.0.1" + }, + "bin": { + "embedeer": "src/cli.js" + }, + "engines": { + "node": ">=18" + } + }, + "packages/ort-linux-x64-cuda": { + "name": "@jsilvanus/embedeer-ort-linux-x64-cuda", + "version": "1.0.0", + "extraneous": true, + "hasInstallScript": true, + "license": "ISC", + "engines": { + "node": ">=18" + } + }, + "packages/ort-win32-x64-cuda": { + "name": "@embedeer/ort-win32-x64-cuda", + "version": "1.0.0", + "extraneous": true, + "hasInstallScript": true, + "license": "ISC", + "engines": { + "node": ">=18" + } + }, + "packages/ort-win32-x64-dml": { + "name": "@jsilvanus/embedeer-ort-win32-x64-dml", + "version": "1.0.0", + "extraneous": true, + "hasInstallScript": true, + "license": "ISC", + "engines": { + "node": ">=18" + } } } } diff --git a/package.json b/package.json index 7193420..bc7829e 100644 --- a/package.json +++ b/package.json @@ -1,14 +1,17 @@ { - "name": "embedeer", + "name": "@jsilvanus/embedeer", "version": "1.0.0", - "description": "A node.js embedding tool", + "description": "A node.js embedding tool with optional GPU 
acceleration", "main": "src/index.js", "bin": { "embedeer": "src/cli.js" }, + "files": [ + "src", + "README.md" + ], "scripts": { - "test": "node --test test/*.test.js", - "start": "node src/cli.js" + "test": "node --test test/*.test.js" }, "repository": { "type": "git", @@ -19,15 +22,24 @@ "huggingface", "nlp", "transformers", - "parallel" + "parallel", + "gpu", + "cuda", + "onnxruntime" ], - "author": "", + "author": "jsilvanus", "license": "ISC", "type": "module", + "engines": { + "node": ">=18" + }, "bugs": { "url": "https://github.com/jsilvanus/embedeer/issues" }, "homepage": "https://github.com/jsilvanus/embedeer#readme", + "publishConfig": { + "access": "public" + }, "dependencies": { "@huggingface/transformers": "^4.0.1" } diff --git a/src/cli.js b/src/cli.js old mode 100644 new mode 100755 index 62d8fce..cd415c9 --- a/src/cli.js +++ b/src/cli.js @@ -5,32 +5,44 @@ * Model management: * embedeer --model Pull / cache a model * - * Embedding: + * Embedding (batch): * embedeer --model --data "text1" "text2" ... 
* embedeer --model --data '["text1","text2"]' * embedeer --model --file texts.txt * echo '["t1","t2"]' | embedeer --model + * printf 'a\0b\0c' | embedeer --model --delimiter '\0' + * + * Interactive / streaming line-reader: + * embedeer --model --interactive --dump out.jsonl + * cat big.txt | embedeer --model --interactive --output csv --dump out.csv * * Options: - * -m, --model Hugging Face model (default: Xenova/all-MiniLM-L6-v2) - * -d, --data Text(s) to embed (JSON array or individual strings) - * --file File of texts (JSON array or one text per line) - * --dump Write output to file instead of stdout - * --output json|txt|sql Output format (default: json) - * -b, --batch-size Texts per worker batch (default: 32) - * -c, --concurrency Parallel worker processes/threads (default: 2) - * --mode process|thread Worker mode: isolated processes or in-process threads (default: process) - * -p, --pooling Pooling: mean|cls|none (default: mean) - * --no-normalize Disable L2 normalisation - * --dtype Quantization: fp32|fp16|q8|q4|q4f16|auto - * --token Hugging Face API token (overrides HF_TOKEN env var) - * --cache-dir Custom model cache directory (default: ~/.embedeer/models) - * -h, --help Show this help + * -m, --model Hugging Face model (default: Xenova/all-MiniLM-L6-v2) + * -d, --data Text(s) to embed (JSON array or individual strings) + * --file File of texts (JSON array or one text per line) + * -D, --delimiter Record separator for stdin/file input (default: \n) + * Escape sequences: \0 (null byte), \n, \t, \r + * -i, --interactive Interactive line-reader: embed as lines arrive + * --dump Write output to file instead of stdout + * --output Output format: json|jsonl|csv|txt|sql (default: json) + * --with-text Include source text in json/txt output + * -b, --batch-size Texts per worker batch (default: 32) + * -c, --concurrency Parallel worker processes/threads (default: 2) + * --mode process|thread Worker mode (default: process) + * -p, --pooling Pooling: mean|cls|none 
(default: mean) + * --no-normalize Disable L2 normalisation + * --dtype Quantization: fp32|fp16|q8|q4|q4f16|auto + * --token Hugging Face API token (overrides HF_TOKEN env var) + * --cache-dir Custom model cache directory (default: ~/.embedeer/models) + * --device Compute device: auto|cpu|gpu (default: cpu) + * --provider Execution provider override: cpu|cuda|dml + * -h, --help Show this help */ import { Embedder } from './embedder.js'; import { getCacheDir, DEFAULT_CACHE_DIR } from './model-cache.js'; -import { readFileSync, writeFileSync } from 'fs'; +import { readFileSync, writeFileSync, appendFileSync } from 'fs'; +import readline from 'readline'; // ── Argument parsing ──────────────────────────────────────────────────────── @@ -43,26 +55,38 @@ embedeer — parallel batched embeddings from Hugging Face Model management (pull / cache): embedeer --model -Embedding: +Embedding (batch): embedeer --model [--data "text1" "text2" ...] embedeer --model --file texts.txt echo '["t1","t2"]' | embedeer --model + printf 'a\\0b\\0c' | embedeer --model --delimiter '\\0' + +Interactive / streaming line-reader: + embedeer --model --interactive --dump out.jsonl + cat big.txt | embedeer --model -i --output csv --dump out.csv Options: - -m, --model Hugging Face model (default: Xenova/all-MiniLM-L6-v2) - -d, --data Text(s) or JSON array to embed - --file Input file: JSON array or one text per line - --dump Write output to file instead of stdout - --output json|txt|sql Output format (default: json) - -b, --batch-size Texts per worker batch (default: 32) - -c, --concurrency Parallel workers (default: 2) - --mode process|thread Worker mode (default: process) - -p, --pooling mean|cls|none (default: mean) - --no-normalize Disable L2 normalisation - --dtype Quantization: fp32|fp16|q8|q4|q4f16|auto - --token Hugging Face API token - --cache-dir Model cache directory (default: ${DEFAULT_CACHE_DIR}) - -h, --help Show this help + -m, --model Hugging Face model (default: 
Xenova/all-MiniLM-L6-v2) + -d, --data Text(s) or JSON array to embed + --file Input file: JSON array or delimited texts + -D, --delimiter Record separator for stdin/file (default: \\n) + Escape sequences supported: \\0 \\n \\t \\r + -i, --interactive Interactive line-reader: embed as lines arrive + (empty line or full batch triggers immediate flush) + --dump Write output to file instead of stdout + --output Output: json|jsonl|csv|txt|sql (default: json) + --with-text Include source text alongside each embedding + -b, --batch-size Texts per worker batch (default: 32) + -c, --concurrency Parallel workers (default: 2) + --mode process|thread Worker mode (default: process) + -p, --pooling mean|cls|none (default: mean) + --no-normalize Disable L2 normalisation + --dtype Quantization: fp32|fp16|q8|q4|q4f16|auto + --token Hugging Face API token + --cache-dir Model cache directory (default: ${DEFAULT_CACHE_DIR}) + --device Compute device: auto|cpu|gpu (default: cpu) + --provider Execution provider override: cpu|cuda|dml + -h, --help Show this help `.trim()); } @@ -70,14 +94,20 @@ Options: // --data so that negative numbers or hyphen-prefixed strings work correctly. 
const KNOWN_FLAGS = new Set([ '--help', '-h', '--model', '-m', '--data', '-d', '--file', '--dump', - '--output', '--batch-size', '-b', '--concurrency', '-c', '--mode', - '--pooling', '-p', '--no-normalize', '--dtype', '--token', '--cache-dir', + '--output', '--with-text', '--batch-size', '-b', '--concurrency', '-c', + '--mode', '--pooling', '-p', '--no-normalize', '--dtype', '--token', + '--cache-dir', '--device', '--provider', '--delimiter', '-D', + '--interactive', '-i', ]); +const options = { model: 'Xenova/all-MiniLM-L6-v2', - data: null, // --data texts (array) - file: null, // --file path - dump: null, // --dump path - output: 'json', // json | txt | sql + data: null, // --data texts (array) + file: null, // --file path + delimiter: '\n', // --delimiter record separator for stdin/file + interactive: false, // --interactive / -i: line-reader mode + dump: null, // --dump path + output: 'json', // json | jsonl | csv | txt | sql + withText: false, // --with-text: include source text in output batchSize: 32, concurrency: 2, mode: 'process', @@ -86,6 +116,8 @@ const KNOWN_FLAGS = new Set([ dtype: undefined, token: undefined, cacheDir: undefined, + device: undefined, + provider: undefined, }; const positional = []; @@ -106,10 +138,16 @@ for (let i = 0; i < args.length; i++) { } } else if (arg === '--file') { options.file = args[++i]; + } else if (arg === '--delimiter' || arg === '-D') { + options.delimiter = parseDelimiter(args[++i]); + } else if (arg === '--interactive' || arg === '-i') { + options.interactive = true; } else if (arg === '--dump') { options.dump = args[++i]; } else if (arg === '--output') { options.output = args[++i]; + } else if (arg === '--with-text') { + options.withText = true; } else if (arg === '--batch-size' || arg === '-b') { options.batchSize = parseInt(args[++i], 10); } else if (arg === '--concurrency' || arg === '-c') { @@ -126,6 +164,10 @@ for (let i = 0; i < args.length; i++) { options.token = args[++i]; } else if (arg === 
'--cache-dir') { options.cacheDir = args[++i]; + } else if (arg === '--device') { + options.device = args[++i]; + } else if (arg === '--provider') { + options.provider = args[++i]; } else { positional.push(arg); } @@ -133,9 +175,28 @@ for (let i = 0; i < args.length; i++) { // ── Output formatting ─────────────────────────────────────────────────────── -function formatOutput(texts, embeddings, format) { +function formatOutput(texts, embeddings, format, withText) { switch (format) { + case 'jsonl': + return texts + .map((text, i) => JSON.stringify({ text, embedding: embeddings[i] })) + .join('\n'); + + case 'csv': { + if (embeddings.length === 0) return ''; + const dims = embeddings[0].length; + const header = ['text', ...Array.from({ length: dims }, (_, k) => `dim_${k}`)].join(','); + const rows = texts.map((text, i) => { + const safeText = '"' + text.replace(/"/g, '""') + '"'; + return [safeText, ...embeddings[i]].join(','); + }); + return [header, ...rows].join('\n'); + } + case 'txt': + if (withText) { + return texts.map((text, i) => `${text}\t${embeddings[i].join(' ')}`).join('\n'); + } return embeddings.map((vec) => vec.join(' ')).join('\n'); case 'sql': { @@ -152,6 +213,11 @@ function formatOutput(texts, embeddings, format) { } default: // json + if (withText) { + return JSON.stringify( + texts.map((text, i) => ({ text, embedding: embeddings[i] })) + ); + } return JSON.stringify(embeddings); } } @@ -167,13 +233,29 @@ function writeOutput(content, dumpPath) { // ── Input reading ─────────────────────────────────────────────────────────── -function parseTexts(raw) { +/** + * Convert a user-supplied delimiter string, resolving common escape sequences. + * Supports: \0 (null byte), \n (newline), \t (tab), \r (carriage return). + */ +export function parseDelimiter(str) { + return str + .replace(/\\0/g, '\0') + .replace(/\\n/g, '\n') + .replace(/\\t/g, '\t') + .replace(/\\r/g, '\r'); +} + +/** + * Parse a block of text into an array of strings. 
+ * First tries to parse as a JSON array; if that fails, splits on `delimiter`. + */ +export function parseTexts(raw, delimiter = '\n') { try { const parsed = JSON.parse(raw); if (!Array.isArray(parsed)) throw new Error('Expected a JSON array'); return parsed; } catch { - return raw.split('\n').filter(Boolean); + return raw.split(delimiter).filter(Boolean); } } @@ -189,11 +271,176 @@ async function readStdin() { }); } -// ── Main ──────────────────────────────────────────────────────────────────── +// ── Interactive / streaming line-reader mode ──────────────────────────────── + +/** + * Interactive mode: read one text record per line from stdin, embed in + * configurable batches, and stream results to a file (or stdout). + * + * Flushing triggers: + * • Batch reaches --batch-size lines (auto-flush) + * • User types an empty line (manual flush) + * • EOF / Ctrl+D (flush remaining records and exit) + * • Ctrl+C (flush remaining records and exit) + * + * Output: + * • Formats json and sql are not appendable — they are promoted to jsonl. + * • csv writes the dimension header once (on the first batch). + * • All other formats append each batch as independent lines. + * • Progress/prompt messages always go to stderr. + */ +async function runInteractive(cacheDir) { + // json and sql produce complete documents that can't be appended to + // incrementally; switch to jsonl so each batch emits self-contained lines. + if (options.output === 'json' || options.output === 'sql') { + console.error( + `Warning: --output ${options.output} is not suitable for interactive mode. Switching to jsonl.` + ); + options.output = 'jsonl'; + } + + const isTTY = process.stdin.isTTY; + const outputFile = options.dump; + + if (isTTY && !outputFile) { + console.error( + 'Tip: use --dump to write output to a file so it does not mix with input.' + ); + } + + // Load the model before opening the reader so we are ready to embed immediately. 
+ console.error(`Loading model: ${options.model}…`); + const embedder = await Embedder.create(options.model, { + batchSize: options.batchSize, + concurrency: options.concurrency, + mode: options.mode, + pooling: options.pooling, + normalize: options.normalize, + dtype: options.dtype, + token: options.token, + cacheDir, + device: options.device, + provider: options.provider, + }); + + if (isTTY) { + console.error('Model ready. Paste records below, one per line.'); + console.error(`Batch size: ${options.batchSize}. Empty line = flush now. Ctrl+D = done. Ctrl+C = abort.`); + } + + // Initialise / clear the output file so each interactive session starts fresh. + if (outputFile) { + writeFileSync(outputFile, '', 'utf8'); + } + + let csvHeaderWritten = false; + let batch = []; + let batchNumber = 0; + let flushing = false; + + /** + * Embed the current batch and write its output. + * The readline interface must be paused before calling this. + */ + async function flushBatch() { + if (batch.length === 0) return; + const texts = [...batch]; + batch = []; + batchNumber++; + + const embeddings = await embedder.embed(texts); + let content; + + if (options.output === 'csv') { + const full = formatOutput(texts, embeddings, 'csv', options.withText); + if (!csvHeaderWritten) { + content = full; // includes header + csvHeaderWritten = true; + } else { + content = full.split('\n').slice(1).join('\n'); // data rows only + } + } else { + content = formatOutput(texts, embeddings, options.output, options.withText); + } + + if (outputFile) { + appendFileSync(outputFile, content + '\n', 'utf8'); + console.error(`Batch ${batchNumber}: ${texts.length} record(s) → ${outputFile}`); + } else { + process.stdout.write(content + '\n'); + } + } + + const rl = readline.createInterface({ + input: process.stdin, + // Route the prompt to stderr so it never pollutes stdout embeddings. + output: isTTY ? 
process.stderr : null, + terminal: isTTY, + }); + + if (isTTY) rl.prompt(); + + rl.on('line', (line) => { + const text = line.trim(); + + if (text !== '') { + batch.push(text); + } + + const shouldFlush = text === '' || batch.length >= options.batchSize; + + if (shouldFlush && !flushing && batch.length > 0) { + flushing = true; + rl.pause(); + flushBatch() + .then(() => { + flushing = false; + rl.resume(); + if (isTTY) rl.prompt(); + }) + .catch((err) => { + console.error('Error embedding batch:', err.message); + flushing = false; + rl.resume(); + if (isTTY) rl.prompt(); + }); + } else if (isTTY) { + rl.prompt(); + } + }); + + await new Promise((resolve) => { + rl.on('close', async () => { + try { + await flushBatch(); + } catch (err) { + console.error('Error embedding final batch:', err.message); + } + await embedder.destroy(); + if (outputFile) { + console.error(`Done. ${batchNumber} batch(es) written to ${outputFile}`); + } + resolve(); + }); + + // Handle Ctrl+C — flush remaining records then exit cleanly. + rl.on('SIGINT', () => { + console.error('\nInterrupted — flushing remaining records…'); + rl.close(); // triggers 'close' event above + }); + }); +} + + async function main() { const resolvedCacheDir = getCacheDir(options.cacheDir); + // ── Interactive line-reader mode ───────────────────────────────────────── + if (options.interactive) { + return runInteractive(resolvedCacheDir); + } + // ── Model-only mode (pull / cache) ────────────────────────────────────── const hasDataSource = options.data || options.file || positional.length > 0; if (!hasDataSource) { @@ -211,7 +458,7 @@ async function main() { return; } // Stdin provided — treat as text input. 
- const texts = parseTexts(stdinRaw); + const texts = parseTexts(stdinRaw, options.delimiter); return runEmbedding(texts, resolvedCacheDir); } @@ -220,7 +467,7 @@ async function main() { if (options.file) { const raw = readFileSync(options.file, 'utf8').trim(); - texts = parseTexts(raw); + texts = parseTexts(raw, options.delimiter); } else if (options.data && options.data.length > 0) { // --data may be a JSON array in a single arg or multiple plain strings if (options.data.length === 1) { @@ -250,11 +497,13 @@ async function runEmbedding(texts, cacheDir) { dtype: options.dtype, token: options.token, cacheDir, + device: options.device, + provider: options.provider, }); try { const embeddings = await embedder.embed(texts); - const content = formatOutput(texts, embeddings, options.output); + const content = formatOutput(texts, embeddings, options.output, options.withText); writeOutput(content, options.dump); } finally { await embedder.destroy(); diff --git a/src/embedder.js b/src/embedder.js index 2310829..e01ceac 100644 --- a/src/embedder.js +++ b/src/embedder.js @@ -29,6 +29,8 @@ export class Embedder { * @param {string} [options.token] Hugging Face API token (overrides HF_TOKEN env) * @param {string} [options.dtype] Quantization dtype ('fp32'|'fp16'|'q8'|'q4'|'q4f16'|'auto') * @param {string} [options.cacheDir] Custom model cache directory + * @param {string} [options.device] Compute device: 'auto'|'cpu'|'gpu' (default: 'cpu') + * @param {string} [options.provider] Execution provider override: 'cpu'|'cuda'|'dml' */ constructor(modelName = 'Xenova/all-MiniLM-L6-v2', options = {}) { this.modelName = modelName; @@ -41,6 +43,8 @@ export class Embedder { token: options.token, dtype: options.dtype, cacheDir: options.cacheDir ?? 
getCacheDir(), + device: options.device, + provider: options.provider, }); } diff --git a/src/provider-loader.js b/src/provider-loader.js new file mode 100644 index 0000000..e4c80ba --- /dev/null +++ b/src/provider-loader.js @@ -0,0 +1,249 @@ +/** + * Provider loader — selects and activates an ONNX Runtime execution provider + * before @huggingface/transformers creates its pipeline. + * + * onnxruntime-node (a transitive dependency of @huggingface/transformers@4.x) + * already ships the CUDA provider on Linux x64 and DirectML on Windows x64 with + * no additional packages needed. This module performs the necessary system checks + * (NVIDIA driver, CUDA libraries) and returns the device string to pass to + * pipeline(). + * + * Usage: + * import { resolveProvider } from './provider-loader.js'; + * const deviceStr = await resolveProvider(device, provider); + * // pass deviceStr to pipeline() if truthy + */ + +import { execSync } from 'child_process'; +import { existsSync } from 'fs'; + +// ── CUDA (linux/x64) ───────────────────────────────────────────────────────── + +/** + * Shared libraries required by libonnxruntime_providers_cuda.so (CUDA 12 / cuDNN 9). + * These are system-installed libraries; they are NOT bundled with onnxruntime-node. + */ +const REQUIRED_CUDA_LIBS = [ + 'libcudart.so.12', + 'libcublas.so.12', + 'libcublasLt.so.12', + 'libcurand.so.10', + 'libcufft.so.11', + 'libcudnn.so.9', +]; + +/** + * Common directories where CUDA libraries may be installed. + * Includes entries from LD_LIBRARY_PATH so custom installs are detected. + * @returns {string[]} + */ +function cudaSearchDirs() { + const extra = (process.env.LD_LIBRARY_PATH ?? 
'').split(':').filter(Boolean); + return [ + '/usr/local/cuda/lib64', + '/usr/local/cuda-12/lib64', + '/usr/local/cuda-12.0/lib64', + '/usr/local/cuda-12.1/lib64', + '/usr/local/cuda-12.2/lib64', + '/usr/local/cuda-12.3/lib64', + '/usr/local/cuda-12.4/lib64', + '/usr/local/cuda-12.5/lib64', + '/usr/local/cuda-12.6/lib64', + '/usr/lib/x86_64-linux-gnu', + '/usr/lib64', + ...extra, + ]; +} + +/** + * Find a shared library by name. Checks common CUDA paths then falls back to + * `ldconfig -p` for libraries registered in the dynamic linker cache. + * + * @param {string} libName e.g. 'libcudart.so.12' + * @returns {string|null} Path to the library, or null if not found. + */ +function findLib(libName) { + for (const dir of cudaSearchDirs()) { + if (existsSync(`${dir}/${libName}`)) return `${dir}/${libName}`; + } + try { + const output = execSync('ldconfig -p', { + stdio: ['ignore', 'pipe', 'ignore'], + encoding: 'utf8', + timeout: 3000, + }); + for (const line of output.split('\n')) { + if (line.includes(libName) && line.includes('=>')) { + const match = line.match(/=>\s*(.+)/); + if (match) return match[1].trim(); + } + } + } catch { + // ldconfig not available in all environments + } + return null; +} + +/** + * Activate the CUDA execution provider. + * Checks for NVIDIA GPU driver and required CUDA 12 / cuDNN 9 system libraries. + * + * @returns {Promise<void>} + * @throws {Error} If NVIDIA GPU is not detected or required CUDA libraries are missing. + */ +async function activateCuda() { + if (!existsSync('/dev/nvidiactl')) { + throw new Error( + 'No NVIDIA GPU detected (/dev/nvidiactl not found).\n' + + 'Ensure NVIDIA drivers are installed. Verify with: nvidia-smi', + ); + } + + const missing = REQUIRED_CUDA_LIBS.filter((lib) => findLib(lib) === null); + if (missing.length > 0) { + throw new Error( + `Missing CUDA system libraries: ${missing.join(', ')}\n\n` + + 'onnxruntime-node CUDA requires CUDA 12 + cuDNN 9.
Install them:\n\n' + + ' # Option A — CUDA 12 + cuDNN 9 via apt (Ubuntu/Debian)\n' + + ' sudo apt install cuda-toolkit-12-6 libcudnn9-cuda-12\n\n' + + ' # Option B — CUDA Toolkit installer from NVIDIA\n' + + ' https://developer.nvidia.com/cuda-downloads\n' + + ' https://developer.nvidia.com/cudnn-downloads\n\n' + + ' # After installing, make sure libraries are on LD_LIBRARY_PATH if non-standard:\n' + + ' export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH', + ); + } +} + +// ── DirectML (win32/x64) ───────────────────────────────────────────────────── + +/** + * Activate the DirectML execution provider. + * DirectML is bundled with onnxruntime-node on Windows. Just verifies platform. + * + * @returns {Promise<void>} + * @throws {Error} If not running on Windows. + */ +async function activateDml() { + if (process.platform !== 'win32') { + throw new Error( + `DirectML is only available on Windows (current platform: ${process.platform}).`, + ); + } +} + +// ── Internal provider map ──────────────────────────────────────────────────── + +/** + * Internal map of "<platform>-<arch>-<provider>" to inline activation logic. + * Replacing the old external-package-per-provider pattern since onnxruntime-node + * already bundles the CUDA and DirectML providers. + * + * @type {Record<string, { activate: () => Promise<void>, getDevice: () => string }>} + */ +const PROVIDER_IMPLS = { + 'linux-x64-cuda': { activate: activateCuda, getDevice: () => 'cuda' }, + 'win32-x64-dml': { activate: activateDml, getDevice: () => 'dml' }, +}; + +// ── Public API ─────────────────────────────────────────────────────────────── + +/** + * Returns the ordered list of preferred GPU providers for the current platform.
+ * @returns {string[]} + */ +export function getPlatformDefaultProviders() { + const platform = process.platform; + const arch = process.arch; + if (platform === 'linux' && arch === 'x64') return ['cuda']; + if (platform === 'win32' && arch === 'x64') return ['cuda', 'dml']; + return []; +} + +/** + * Attempt to activate a specific provider. Returns a result object: + * - { loaded: true, deviceStr, error: null } — provider ready + * - { loaded: false, deviceStr: null, error } — provider unavailable + * + * @param {string} provider e.g. 'cuda' or 'dml' + * @returns {Promise<{loaded: boolean, deviceStr: string|null, error: Error|null}>} + */ +export async function tryLoadProvider(provider) { + const key = `${process.platform}-${process.arch}-${provider}`; + const impl = PROVIDER_IMPLS[key]; + if (!impl) { + return { loaded: false, deviceStr: null, error: null }; + } + try { + await impl.activate(); + const deviceStr = impl.getDevice(); + return { loaded: true, deviceStr, error: null }; + } catch (err) { + return { loaded: false, deviceStr: null, error: err }; + } +} + +/** + * Resolve and activate the appropriate execution provider, returning the + * device string to pass to `@huggingface/transformers` pipeline(). + * + * @param {'auto'|'cpu'|'gpu'|undefined} device + * @param {'cpu'|'cuda'|'dml'|undefined} provider Optional explicit override + * @returns {Promise<string|undefined>} Device string or undefined (CPU default) + * + * @throws {Error} When an explicit provider is requested but not available. + * @throws {Error} When device='gpu' and no GPU provider is available. + */ +export async function resolveProvider(device, provider) { + const dev = (device ?? 'cpu').toLowerCase(); + const prov = provider ?
provider.toLowerCase() : undefined; + + // --- Explicit CPU --- + if (dev === 'cpu' && !prov) return undefined; + if (prov === 'cpu') return undefined; + + // --- Explicit provider --- + if (prov && prov !== 'cpu') { + const key = `${process.platform}-${process.arch}-${prov}`; + if (!PROVIDER_IMPLS[key]) { + const supportedPlatforms = Object.keys(PROVIDER_IMPLS) + .filter((k) => k.endsWith(`-${prov}`)) + .map((k) => k.replace(`-${prov}`, '')); + throw new Error( + `Provider '${prov}' is not supported on ${process.platform}/${process.arch}. ` + + `Supported platforms: ${supportedPlatforms.join(', ') || 'none'}.`, + ); + } + + const { loaded, deviceStr, error } = await tryLoadProvider(prov); + if (!loaded) { + if (error) throw error; + throw new Error( + `Provider '${prov}' is not available on ${process.platform}/${process.arch}.`, + ); + } + return deviceStr ?? undefined; + } + + // --- device='gpu' or device='auto': try platform defaults in order --- + const candidates = getPlatformDefaultProviders(); + let lastError = null; + + for (const candidate of candidates) { + const { loaded, deviceStr, error } = await tryLoadProvider(candidate); + if (loaded) return deviceStr ?? candidate; + if (error) lastError = error; + } + + if (dev === 'gpu') { + if (lastError) throw lastError; + throw new Error( + `device='gpu' was requested but no GPU provider is available ` + + `for ${process.platform}/${process.arch}. 
` + + `Supported: linux/x64 (CUDA 12 + cuDNN 9), win32/x64 (DirectML).`, + ); + } + + // device='auto' and no GPU available → silently fall back to CPU + return undefined; +} diff --git a/src/thread-worker-script.js b/src/thread-worker-script.js index e2b66b0..f34c762 100644 --- a/src/thread-worker-script.js +++ b/src/thread-worker-script.js @@ -18,8 +18,9 @@ import { workerData, parentPort } from 'worker_threads'; import { pipeline, env } from '@huggingface/transformers'; import { buildPipelineOptions } from './model-cache.js'; +import { resolveProvider } from './provider-loader.js'; -const { modelName, pooling, normalize, token, dtype, cacheDir } = workerData; +const { modelName, pooling, normalize, token, dtype, cacheDir, device, provider } = workerData; // Apply configuration before loading the model. if (token) process.env.HF_TOKEN = token; @@ -28,7 +29,13 @@ if (cacheDir) env.cacheDir = cacheDir; let extractor; async function init() { - extractor = await pipeline('feature-extraction', modelName, buildPipelineOptions(dtype)); + // Activate GPU provider (if requested) before creating the pipeline. + const deviceStr = await resolveProvider(device, provider); + const pipelineOpts = { + ...buildPipelineOptions(dtype), + ...(deviceStr ? 
{ device: deviceStr } : {}), + }; + extractor = await pipeline('feature-extraction', modelName, pipelineOpts); parentPort.postMessage({ type: 'ready' }); } diff --git a/src/worker-pool.js b/src/worker-pool.js index a6ac0b5..951eb26 100644 --- a/src/worker-pool.js +++ b/src/worker-pool.js @@ -33,6 +33,8 @@ export class WorkerPool { * @param {string} [options.token] Hugging Face API token (overrides HF_TOKEN env var) * @param {string} [options.dtype] Quantization dtype ('fp32'|'fp16'|'q8'|'q4'|'q4f16'|'auto') * @param {string} [options.cacheDir] Custom model cache directory + * @param {string} [options.device] Compute device: 'auto'|'cpu'|'gpu' (default: 'cpu') + * @param {string} [options.provider] Execution provider override: 'cpu'|'cuda'|'dml' * @param {Function} [options._WorkerClass] Override worker class (for testing) */ constructor(modelName, { @@ -43,6 +45,8 @@ export class WorkerPool { token, dtype, cacheDir, + device, + provider, _WorkerClass, } = {}) { this.modelName = modelName; @@ -53,6 +57,8 @@ export class WorkerPool { this.token = token; this.dtype = dtype; this.cacheDir = cacheDir; + this.device = device; + this.provider = provider; // Pick defaults based on mode; can be overridden for testing. if (_WorkerClass) { @@ -148,6 +154,8 @@ export class WorkerPool { token: this.token, dtype: this.dtype, cacheDir: this.cacheDir, + device: this.device, + provider: this.provider, }, }); diff --git a/src/worker.js b/src/worker.js index 3bd65da..6e308f1 100644 --- a/src/worker.js +++ b/src/worker.js @@ -18,6 +18,7 @@ import { pipeline, env } from '@huggingface/transformers'; import { buildPipelineOptions } from './model-cache.js'; +import { resolveProvider } from './provider-loader.js'; let extractor; let pooling; @@ -30,7 +31,13 @@ process.on('message', async (msg) => { // Apply auth and cache config before loading the model. 
if (msg.token) process.env.HF_TOKEN = msg.token; if (msg.cacheDir) env.cacheDir = msg.cacheDir; - extractor = await pipeline('feature-extraction', msg.modelName, buildPipelineOptions(msg.dtype)); + // Activate GPU provider (if requested) before creating the pipeline. + const deviceStr = await resolveProvider(msg.device, msg.provider); + const pipelineOpts = { + ...buildPipelineOptions(msg.dtype), + ...(deviceStr ? { device: deviceStr } : {}), + }; + extractor = await pipeline('feature-extraction', msg.modelName, pipelineOpts); process.send({ type: 'ready' }); } catch (err) { process.send({ type: 'error', id: null, error: err.message }); diff --git a/test/cli-format.test.js b/test/cli-format.test.js index 818d29d..06fcb1b 100644 --- a/test/cli-format.test.js +++ b/test/cli-format.test.js @@ -1,5 +1,5 @@ /** - * Tests for CLI output formatting and model-cache helpers. + * Tests for CLI output formatting, input parsing, and model-cache helpers. * These are pure unit tests — no workers, no network. */ @@ -30,40 +30,192 @@ describe('model-cache', async () => { }); }); -describe('CLI output formatting', async () => { - // We test the formatting logic by importing the private helpers. - // Since cli.js is a script, we extract the formatting to test it directly. 
- - function formatOutput(texts, embeddings, format) { - switch (format) { - case 'txt': - return embeddings.map((vec) => vec.join(' ')).join('\n'); - case 'sql': { - const rows = texts.map((text, i) => { - const safeText = text.replace(/'/g, "''"); - const vector = JSON.stringify(embeddings[i]); - return ` ('${safeText}', '${vector}')`; - }); - return ( - 'INSERT INTO embeddings (text, vector) VALUES\n' + - rows.join(',\n') + - ';' - ); +// ── Inline helpers mirroring cli.js (cli.js runs main() on import) ────────── + +function parseDelimiter(str) { + return str + .replace(/\\0/g, '\0') + .replace(/\\n/g, '\n') + .replace(/\\t/g, '\t') + .replace(/\\r/g, '\r'); +} + +function parseTexts(raw, delimiter = '\n') { + try { + const parsed = JSON.parse(raw); + if (!Array.isArray(parsed)) throw new Error('Expected a JSON array'); + return parsed; + } catch { + return raw.split(delimiter).filter(Boolean); + } +} + +function formatOutput(texts, embeddings, format, withText = false) { + switch (format) { + case 'jsonl': + return texts + .map((text, i) => JSON.stringify({ text, embedding: embeddings[i] })) + .join('\n'); + + case 'csv': { + if (embeddings.length === 0) return ''; + const dims = embeddings[0].length; + const header = ['text', ...Array.from({ length: dims }, (_, k) => `dim_${k}`)].join(','); + const rows = texts.map((text, i) => { + const safeText = '"' + text.replace(/"/g, '""') + '"'; + return [safeText, ...embeddings[i]].join(','); + }); + return [header, ...rows].join('\n'); + } + + case 'txt': + if (withText) { + return texts.map((text, i) => `${text}\t${embeddings[i].join(' ')}`).join('\n'); } - default: - return JSON.stringify(embeddings); + return embeddings.map((vec) => vec.join(' ')).join('\n'); + + case 'sql': { + const rows = texts.map((text, i) => { + const safeText = text.replace(/'/g, "''"); + const vector = JSON.stringify(embeddings[i]); + return ` ('${safeText}', '${vector}')`; + }); + return ( + 'INSERT INTO embeddings (text, vector) 
VALUES\n' + + rows.join(',\n') + + ';' + ); } + + default: // json + if (withText) { + return JSON.stringify( + texts.map((text, i) => ({ text, embedding: embeddings[i] })) + ); + } + return JSON.stringify(embeddings); } +} + +// ── parseDelimiter ─────────────────────────────────────────────────────────── + +describe('parseDelimiter', () => { + test('leaves plain string unchanged', () => { + assert.equal(parseDelimiter('|||'), '|||'); + }); + + test('\\0 becomes null byte', () => { + assert.equal(parseDelimiter('\\0'), '\0'); + }); + + test('\\n becomes newline', () => { + assert.equal(parseDelimiter('\\n'), '\n'); + }); + + test('\\t becomes tab', () => { + assert.equal(parseDelimiter('\\t'), '\t'); + }); + + test('\\r becomes carriage return', () => { + assert.equal(parseDelimiter('\\r'), '\r'); + }); + + test('multiple escape sequences in one string', () => { + assert.equal(parseDelimiter('\\r\\n'), '\r\n'); + }); +}); + +// ── parseTexts ─────────────────────────────────────────────────────────────── + +describe('parseTexts', () => { + test('JSON array is parsed directly', () => { + const result = parseTexts('["a","b","c"]'); + assert.deepEqual(result, ['a', 'b', 'c']); + }); + + test('defaults to newline delimiter', () => { + const result = parseTexts('foo\nbar\nbaz'); + assert.deepEqual(result, ['foo', 'bar', 'baz']); + }); + test('custom delimiter (pipe)', () => { + const result = parseTexts('foo|bar|baz', '|'); + assert.deepEqual(result, ['foo', 'bar', 'baz']); + }); + + test('custom delimiter (null byte)', () => { + const result = parseTexts('foo\0bar\0baz', '\0'); + assert.deepEqual(result, ['foo', 'bar', 'baz']); + }); + + test('custom delimiter (tab)', () => { + const result = parseTexts('foo\tbar\tbaz', '\t'); + assert.deepEqual(result, ['foo', 'bar', 'baz']); + }); + + test('filters empty strings after split', () => { + const result = parseTexts('\nfoo\n\nbar\n', '\n'); + assert.deepEqual(result, ['foo', 'bar']); + }); + + test('JSON array takes 
precedence over delimiter parsing', () => { + const result = parseTexts('["x","y"]', '|'); + assert.deepEqual(result, ['x', 'y']); + }); +}); + +// ── CLI output formatting ──────────────────────────────────────────────────── + +describe('CLI output formatting', () => { const texts = ['Hello world', "It's a test"]; const embeddings = [[0.1, 0.2], [0.3, 0.4]]; - test('json output is valid JSON array', () => { + test('json output is valid JSON array of vectors', () => { const out = formatOutput(texts, embeddings, 'json'); const parsed = JSON.parse(out); assert.deepEqual(parsed, embeddings); }); + test('json --with-text wraps each item with text field', () => { + const out = formatOutput(texts, embeddings, 'json', true); + const parsed = JSON.parse(out); + assert.equal(parsed.length, 2); + assert.equal(parsed[0].text, 'Hello world'); + assert.deepEqual(parsed[0].embedding, [0.1, 0.2]); + assert.equal(parsed[1].text, "It's a test"); + assert.deepEqual(parsed[1].embedding, [0.3, 0.4]); + }); + + test('jsonl produces one JSON object per line', () => { + const out = formatOutput(texts, embeddings, 'jsonl'); + const lines = out.split('\n'); + assert.equal(lines.length, 2); + const first = JSON.parse(lines[0]); + assert.equal(first.text, 'Hello world'); + assert.deepEqual(first.embedding, [0.1, 0.2]); + const second = JSON.parse(lines[1]); + assert.equal(second.text, "It's a test"); + assert.deepEqual(second.embedding, [0.3, 0.4]); + }); + + test('csv produces header row and data rows', () => { + const out = formatOutput(texts, embeddings, 'csv'); + const lines = out.split('\n'); + assert.equal(lines[0], 'text,dim_0,dim_1'); + assert.equal(lines[1], '"Hello world",0.1,0.2'); + assert.equal(lines[2], '"It\'s a test",0.3,0.4'); + }); + + test('csv escapes double-quotes in text', () => { + const out = formatOutput(['say "hi"'], [[1, 2]], 'csv'); + const lines = out.split('\n'); + assert.equal(lines[1], '"say ""hi""",1,2'); + }); + + test('csv returns empty string for zero 
embeddings', () => { + assert.equal(formatOutput([], [], 'csv'), ''); + }); + test('txt output is one space-separated line per embedding', () => { const out = formatOutput(texts, embeddings, 'txt'); const lines = out.split('\n'); @@ -72,6 +224,13 @@ describe('CLI output formatting', async () => { assert.equal(lines[1], '0.3 0.4'); }); + test('txt --with-text prefixes each line with text and tab', () => { + const out = formatOutput(texts, embeddings, 'txt', true); + const lines = out.split('\n'); + assert.equal(lines[0], 'Hello world\t0.1 0.2'); + assert.equal(lines[1], "It's a test\t0.3 0.4"); + }); + test('sql output starts with INSERT and contains both rows', () => { const out = formatOutput(texts, embeddings, 'sql'); assert.ok(out.startsWith('INSERT INTO embeddings')); diff --git a/test/provider-loader.test.js b/test/provider-loader.test.js new file mode 100644 index 0000000..589085f --- /dev/null +++ b/test/provider-loader.test.js @@ -0,0 +1,257 @@ +/** + * Unit tests for provider-loader.js + * + * Tests verify provider selection logic and error messages when a GPU provider + * is unavailable or unsupported on the current platform. + * + * All tests use process.platform/arch overrides to isolate platform logic + * without requiring real GPU hardware. + */ + +import { test, describe } from 'node:test'; +import assert from 'node:assert/strict'; +import { + getPlatformDefaultProviders, + tryLoadProvider, + resolveProvider, +} from '../src/provider-loader.js'; + +// ── Helpers ────────────────────────────────────────────────────────────────── + +/** + * Temporarily override process.platform and process.arch, restore after fn(). 
+ */ +async function withPlatform(platform, arch, fn) { + const origPlatform = Object.getOwnPropertyDescriptor(process, 'platform'); + const origArch = Object.getOwnPropertyDescriptor(process, 'arch'); + Object.defineProperty(process, 'platform', { value: platform, configurable: true }); + Object.defineProperty(process, 'arch', { value: arch, configurable: true }); + try { + await fn(); + } finally { + if (origPlatform) Object.defineProperty(process, 'platform', origPlatform); + if (origArch) Object.defineProperty(process, 'arch', origArch); + } +} + +// ── getPlatformDefaultProviders() ──────────────────────────────────────────── + +describe('getPlatformDefaultProviders()', () => { + test('returns [cuda] on linux/x64', async () => { + await withPlatform('linux', 'x64', () => { + assert.deepEqual(getPlatformDefaultProviders(), ['cuda']); + }); + }); + + test('returns [cuda, dml] on win32/x64 (CUDA preferred over DML)', async () => { + await withPlatform('win32', 'x64', () => { + assert.deepEqual(getPlatformDefaultProviders(), ['cuda', 'dml']); + }); + }); + + test('returns [] on unsupported platforms (e.g. darwin/arm64)', async () => { + await withPlatform('darwin', 'arm64', () => { + assert.deepEqual(getPlatformDefaultProviders(), []); + }); + }); +}); + +// ── tryLoadProvider() ──────────────────────────────────────────────────────── + +describe('tryLoadProvider()', () => { + test('returns { loaded: false } when provider is not supported on platform', async () => { + await withPlatform('darwin', 'arm64', async () => { + const result = await tryLoadProvider('cuda'); + assert.equal(result.loaded, false); + assert.equal(result.deviceStr, null); + }); + }); + + test('returns { loaded: false } when GPU hardware or system libs are missing', async () => { + // In a typical CI environment there is no NVIDIA GPU, so activateCuda() + // throws when /dev/nvidiactl is missing. tryLoadProvider must catch it + // and return { loaded: false }. 
+ await withPlatform('linux', 'x64', async () => { + const result = await tryLoadProvider('cuda'); + assert.equal(result.loaded, false); + assert.equal(result.deviceStr, null); + // error may be set (GPU not found) or null (provider not implemented) + }); + }); +}); + +// ── resolveProvider() ──────────────────────────────────────────────────────── + +describe('resolveProvider()', () => { + // ── CPU paths ───────────────────────────────────────────────────────────── + + test('returns undefined when device=cpu', async () => { + const result = await resolveProvider('cpu', undefined); + assert.equal(result, undefined); + }); + + test('returns undefined when provider=cpu', async () => { + const result = await resolveProvider('auto', 'cpu'); + assert.equal(result, undefined); + }); + + test('returns undefined when device and provider are both undefined', async () => { + const result = await resolveProvider(undefined, undefined); + assert.equal(result, undefined); + }); + + // ── device=auto with no GPU available ──────────────────────────────────── + + test('device=auto returns undefined (CPU fallback) when GPU provider fails to activate', async () => { + await withPlatform('linux', 'x64', async () => { + // No NVIDIA GPU in CI; device='auto' must silently fall back to CPU. 
+ const result = await resolveProvider('auto', undefined); + assert.equal(result, undefined); + }); + }); + + test('device=auto returns undefined on unsupported platform (no GPU providers)', async () => { + await withPlatform('darwin', 'arm64', async () => { + const result = await resolveProvider('auto', undefined); + assert.equal(result, undefined); + }); + }); + + // ── device=gpu with no GPU available ───────────────────────────────────── + + test('device=gpu throws with GPU-related error when no GPU available (linux/x64)', async () => { + await withPlatform('linux', 'x64', async () => { + // No NVIDIA GPU in CI; resolveProvider should throw with a diagnostic + // message about the GPU or CUDA requirements. + await assert.rejects( + () => resolveProvider('gpu', undefined), + (err) => { + assert.ok( + err.message.toLowerCase().includes('nvidia') || + err.message.toLowerCase().includes('cuda') || + err.message.toLowerCase().includes('gpu'), + `Expected GPU-related context in error, got: ${err.message}`, + ); + return true; + }, + ); + }); + }); + + test('device=gpu throws on unsupported platform with informative message', async () => { + await withPlatform('darwin', 'arm64', async () => { + await assert.rejects( + () => resolveProvider('gpu', undefined), + (err) => { + assert.ok( + err.message.includes("device='gpu'"), + `Expected GPU error message, got: ${err.message}`, + ); + return true; + }, + ); + }); + }); + + // ── explicit provider not available ────────────────────────────────────── + + test('explicit provider=cuda throws with diagnostic error when GPU hardware is missing', async () => { + await withPlatform('linux', 'x64', async () => { + // No NVIDIA GPU in CI; activate() throws a diagnostic error about the + // missing hardware or CUDA libraries. resolveProvider re-throws it. 
+ await assert.rejects( + () => resolveProvider('cpu', 'cuda'), + (err) => { + assert.ok( + err.message.toLowerCase().includes('nvidia') || + err.message.toLowerCase().includes('cuda') || + err.message.toLowerCase().includes('gpu'), + `Expected GPU-related context in error, got: ${err.message}`, + ); + return true; + }, + ); + }); + }); + + test('explicit provider=dml succeeds on win32 when platform is Windows', async () => { + await withPlatform('win32', 'x64', async () => { + // activateDml() checks process.platform === 'win32' (mocked here) and succeeds. + const result = await resolveProvider('cpu', 'dml'); + assert.equal(result, 'dml'); + }); + }); + + // ── unsupported provider on platform ────────────────────────────────────── + + test('explicit provider=dml throws "not supported" on linux', async () => { + await withPlatform('linux', 'x64', async () => { + await assert.rejects( + () => resolveProvider('gpu', 'dml'), + (err) => { + assert.ok( + err.message.toLowerCase().includes('not supported'), + `Expected "not supported" in error, got: ${err.message}`, + ); + return true; + }, + ); + }); + }); + + test('explicit provider=cuda throws "not supported" on darwin/arm64', async () => { + await withPlatform('darwin', 'arm64', async () => { + await assert.rejects( + () => resolveProvider('gpu', 'cuda'), + (err) => { + assert.ok( + err.message.toLowerCase().includes('not supported'), + `Expected "not supported" in error, got: ${err.message}`, + ); + return true; + }, + ); + }); + }); +}); + +// ── WorkerPool device/provider options ─────────────────────────────────────── + +describe('WorkerPool — device and provider options', async () => { + const { WorkerPool } = await import('../src/worker-pool.js'); + const { EventEmitter } = await import('events'); + + class SpyWorker extends EventEmitter { + constructor(scriptPath, opts) { + super(); + SpyWorker.lastOpts = opts; + setImmediate(() => this.emit('message', { type: 'ready' })); + } + postMessage() {} + async 
terminate() { setImmediate(() => this.emit('exit', 0)); } + } + + test('device and provider are stored in WorkerPool', () => { + const pool = new WorkerPool('model', { + _WorkerClass: SpyWorker, + device: 'gpu', + provider: 'cuda', + }); + assert.equal(pool.device, 'gpu'); + assert.equal(pool.provider, 'cuda'); + }); + + test('workerData includes device and provider', async () => { + const pool = new WorkerPool('model', { + _WorkerClass: SpyWorker, + poolSize: 1, + device: 'auto', + provider: 'cuda', + }); + await pool.initialize(); + const wd = SpyWorker.lastOpts.workerData; + assert.equal(wd.device, 'auto'); + assert.equal(wd.provider, 'cuda'); + await pool.destroy(); + }); +});