diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7b672f5..5d5c0a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,143 +1,23 @@ +# CI — thin wrapper that calls the reusable test workflow. + name: CI on: pull_request: branches: [main] + types: [opened, synchronize, reopened, ready_for_review] + workflow_dispatch: + workflow_call: -env: - CARGO_TERM_COLOR: always - -jobs: - # Lint and format check (single job) - lint: - name: Lint & Format - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install Rust - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt, clippy +concurrency: + group: ci-${{ github.ref }} + cancel-in-progress: true - - name: Install PCRE2 - run: sudo apt-get update && sudo apt-get install -y libpcre2-dev +permissions: + contents: read - - name: Check formatting - run: cargo fmt --all --check - - - name: Run clippy - run: cargo clippy --all-targets -- -D warnings - - # Rust tests on multiple platforms +jobs: test: - name: Test (${{ matrix.os }}) - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - - - name: Install Rust - uses: dtolnay/rust-toolchain@stable - - - name: Install PCRE2 (Ubuntu) - if: matrix.os == 'ubuntu-latest' - run: sudo apt-get update && sudo apt-get install -y libpcre2-dev - - - name: Configure Python for PyO3 (Ubuntu) - if: matrix.os == 'ubuntu-latest' - run: echo "PYO3_PYTHON=$(which python3)" >> $GITHUB_ENV - - - name: Install PCRE2 (macOS) - if: matrix.os == 'macos-latest' - run: brew install pcre2 - - - name: Configure Python for PyO3 (macOS) - if: matrix.os == 'macos-latest' - run: | - echo "PYO3_PYTHON=$(which python3)" >> $GITHUB_ENV - # Get Python library directory and set for linker - PYTHON_PREFIX=$(python3 -c "import sys; 
print(sys.prefix)") - echo "LIBRARY_PATH=${PYTHON_PREFIX}/lib" >> $GITHUB_ENV - echo "DYLD_LIBRARY_PATH=${PYTHON_PREFIX}/lib" >> $GITHUB_ENV - # Tell Cargo to link the Python framework - echo "CARGO_BUILD_RUSTFLAGS=-C link-arg=-undefined -C link-arg=dynamic_lookup" >> $GITHUB_ENV - - - name: Install PCRE2 (Windows) - if: matrix.os == 'windows-latest' - run: | - vcpkg install pcre2:x64-windows - echo "PCRE2_SYS_STATIC=1" >> $env:GITHUB_ENV - - - name: Configure Python for PyO3 (Windows) - if: matrix.os == 'windows-latest' - run: | - $pythonPath = (Get-Command python).Source - echo "PYO3_PYTHON=$pythonPath" >> $env:GITHUB_ENV - - - name: Run tests - run: cargo test - - # Python bindings tests - python: - name: Python tests - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - - - name: Install PCRE2 - run: sudo apt-get update && sudo apt-get install -y libpcre2-dev - - - name: Install dependencies and build - run: | - python -m venv .venv - . .venv/bin/activate - python -m pip install --upgrade pip - pip install maturin tiktoken - maturin develop --release - - - name: Test Python bindings - run: | - .venv/bin/python -c " - import splintr - import tiktoken - - # Test cl100k_base - tok = splintr.Tokenizer.from_pretrained('cl100k_base') - tik = tiktoken.get_encoding('cl100k_base') - - text = 'Hello, world!' 
- assert tok.encode(text) == list(tik.encode(text)), 'cl100k_base mismatch' - - # Test o200k_base - tok2 = splintr.Tokenizer.from_pretrained('o200k_base') - tik2 = tiktoken.get_encoding('o200k_base') - assert tok2.encode(text) == list(tik2.encode(text)), 'o200k_base mismatch' - - # Test streaming decoder - decoder = tok.streaming_decoder() - tokens = tok.encode('Hello') - result = [] - for t in tokens: - chunk = decoder.add_token(t) - if chunk: - result.append(chunk) - result.append(decoder.flush()) - assert ''.join(result) == 'Hello', 'Streaming decoder failed' - - print('All Python tests passed!') - " + if: github.event.pull_request.draft == false + name: Test Suite + uses: ./.github/workflows/test.yml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 45243c7..88bd5ed 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,9 +1,10 @@ name: Release +run-name: Release ${{ github.ref_name }} on: push: tags: - - 'v*' + - "v*" permissions: contents: read @@ -18,7 +19,7 @@ jobs: pypi_version: ${{ steps.version.outputs.pypi_version }} base_version: ${{ steps.version.outputs.base_version }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Validate and extract version id: version @@ -78,7 +79,7 @@ jobs: needs: validate-version runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Rust uses: dtolnay/rust-toolchain@stable @@ -102,7 +103,7 @@ jobs: run: cargo publish --allow-dirty --token ${{ secrets.CARGO_REGISTRY_TOKEN }} env: # Enable PCRE2 JIT compilation - PCRE2_SYS_JIT: '1' + PCRE2_SYS_JIT: "1" # Build Python wheels for multiple platforms build-wheels: @@ -115,12 +116,12 @@ jobs: os: [ubuntu-latest, macos-15-intel, macos-14, windows-latest] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.12' + python-version: "3.12" - name: Update 
version in pyproject.toml shell: bash @@ -156,11 +157,11 @@ jobs: # Build with python feature (PyO3 bindings) + pcre2 (PCRE2 backend with JIT) # Note: regexr's SIMD uses runtime detection, JIT is compiled at build time args: --release --out dist --features python,pcre2 - sccache: 'true' + sccache: "true" manylinux: auto env: # Enable PCRE2 JIT compilation - PCRE2_SYS_JIT: '1' + PCRE2_SYS_JIT: "1" - name: Upload wheels uses: actions/upload-artifact@v4 @@ -174,7 +175,7 @@ jobs: needs: validate-version runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Update version in pyproject.toml run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..a891be5 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,163 @@ +# Reusable test workflow: lint, check, and test. +# +# Called by: +# - ci.yml (PR checks) +# - release.yml (pre-publish gate) + +name: Test + +on: + workflow_call: + +permissions: + contents: read + +env: + CARGO_TERM_COLOR: always + +jobs: + lint: + name: Lint & Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + + - name: Install PCRE2 + run: sudo apt-get update && sudo apt-get install -y libpcre2-dev + + - uses: Swatinem/rust-cache@v2 + with: + prefix-key: lint + + - name: Check formatting + run: cargo fmt --all --check + + - name: Run clippy + run: cargo clippy --all-targets -- -D warnings + + test: + name: Test (${{ matrix.target }}) + runs-on: ${{ matrix.runs-on }} + strategy: + fail-fast: false + matrix: + include: + - runs-on: ubuntu-latest + target: x86_64-unknown-linux-gnu + - runs-on: macos-latest + target: aarch64-apple-darwin + - runs-on: windows-latest + target: x86_64-pc-windows-msvc + steps: + - uses: actions/checkout@v5 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: 
Install Rust + uses: dtolnay/rust-toolchain@stable + + - uses: Swatinem/rust-cache@v2 + with: + prefix-key: test-${{ matrix.target }} + + - name: Install PCRE2 (Ubuntu) + if: runner.os == 'Linux' + run: sudo apt-get update && sudo apt-get install -y libpcre2-dev + + - name: Configure Python for PyO3 (Ubuntu) + if: runner.os == 'Linux' + run: echo "PYO3_PYTHON=$(which python3)" >> $GITHUB_ENV + + - name: Install PCRE2 (macOS) + if: runner.os == 'macOS' + run: brew install pcre2 + + - name: Configure Python for PyO3 (macOS) + if: runner.os == 'macOS' + run: | + echo "PYO3_PYTHON=$(which python3)" >> $GITHUB_ENV + PYTHON_PREFIX=$(python3 -c "import sys; print(sys.prefix)") + echo "LIBRARY_PATH=${PYTHON_PREFIX}/lib" >> $GITHUB_ENV + echo "DYLD_LIBRARY_PATH=${PYTHON_PREFIX}/lib" >> $GITHUB_ENV + echo "CARGO_BUILD_RUSTFLAGS=-C link-arg=-undefined -C link-arg=dynamic_lookup" >> $GITHUB_ENV + + - name: Install PCRE2 (Windows) + if: runner.os == 'Windows' + run: | + vcpkg install pcre2:x64-windows + echo "PCRE2_SYS_STATIC=1" >> $env:GITHUB_ENV + + - name: Configure Python for PyO3 (Windows) + if: runner.os == 'Windows' + run: | + $pythonPath = (Get-Command python).Source + echo "PYO3_PYTHON=$pythonPath" >> $env:GITHUB_ENV + + - name: Run tests + run: cargo test + + python: + name: Python bindings + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install PCRE2 + run: sudo apt-get update && sudo apt-get install -y libpcre2-dev + + - uses: Swatinem/rust-cache@v2 + with: + prefix-key: python + + - name: Install dependencies and build + run: | + python -m venv .venv + . 
.venv/bin/activate + python -m pip install --upgrade pip + pip install maturin tiktoken + maturin develop --release + + - name: Test Python bindings + run: | + .venv/bin/python -c " + import splintr + import tiktoken + + # Test cl100k_base + tok = splintr.Tokenizer.from_pretrained('cl100k_base') + tik = tiktoken.get_encoding('cl100k_base') + + text = 'Hello, world!' + assert tok.encode(text) == list(tik.encode(text)), 'cl100k_base mismatch' + + # Test o200k_base + tok2 = splintr.Tokenizer.from_pretrained('o200k_base') + tik2 = tiktoken.get_encoding('o200k_base') + assert tok2.encode(text) == list(tik2.encode(text)), 'o200k_base mismatch' + + # Test streaming decoder + decoder = tok.streaming_decoder() + tokens = tok.encode('Hello') + result = [] + for t in tokens: + chunk = decoder.add_token(t) + if chunk: + result.append(chunk) + result.append(decoder.flush()) + assert ''.join(result) == 'Hello', 'Streaming decoder failed' + + print('All Python tests passed!') + " diff --git a/.version b/.version index ac39a10..f374f66 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -0.9.0 +0.9.1 diff --git a/Cargo.toml b/Cargo.toml index 464c56c..341e456 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "splintr" -version = "0.9.0" +version = "0.9.1" edition = "2021" -description = "Fast Rust tokenizer (BPE + SentencePiece) with Python bindings" +description = "Fast Rust tokenizer (BPE + SentencePiece + WordPiece) with Python bindings" license = "MIT" repository = "https://github.com/ml-rust/splintr" homepage = "https://github.com/ml-rust/splintr" readme = "README.md" -keywords = ["tokenizer", "bpe", "sentencepiece", "tiktoken", "llm"] +keywords = ["tokenizer", "bpe", "sentencepiece", "wordpiece", "llm"] categories = ["text-processing", "encoding"] [lib] @@ -20,7 +20,7 @@ python = ["dep:pyo3"] pcre2 = ["dep:pcre2"] rayon = ["dep:rayon"] regexr-jit = ["regexr/jit", "regexr/simd"] -wasm = [] # disables rayon, uses scalar regex — use with 
--no-default-features +wasm = [] # disables rayon, uses scalar regex — use with --no-default-features [dependencies] # PCRE2 regex with JIT support (optional, for benchmarking) @@ -40,7 +40,11 @@ aho-corasick = "1.1" # LRU cache for frequent token sequences lru = "0.16" # regexr regex engine (default backend) -regexr = { version = "0.1.0-beta.5", default-features = false } +regexr = { version = "0.1", default-features = false } +# Unicode normalization for WordPiece accent stripping +unicode-normalization = "0.1" +# Unicode general category for punctuation detection +unicode-general-category = "1.0" [dev-dependencies] # PCRE2 for benchmarking comparisons diff --git a/README.md b/README.md index 06c12cf..cf08643 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Crates.io](https://img.shields.io/crates/v/splintr.svg)](https://crates.io/crates/splintr) [![PyPI](https://img.shields.io/pypi/v/splintr-rs.svg)](https://pypi.org/project/splintr-rs/) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) -**A high-performance tokenizer (BPE + SentencePiece) built with Rust with Python bindings, focused on speed, safety, and resource optimization.** +**A high-performance tokenizer (BPE + SentencePiece + WordPiece) built with Rust with Python bindings, focused on speed, safety, and resource optimization.** ## The Problem @@ -85,7 +85,7 @@ See the [API Guide](docs/api_guide.md) and [docs.rs](https://docs.rs/splintr) fo - **Compatible vocabularies** - Supports cl100k_base, o200k_base (OpenAI), Llama 3 family (Meta), DeepSeek V3 (DeepSeek), and Mistral V1/V2/V3 (Mistral AI) - **Streaming decoders** - Real-time LLM output display with proper UTF-8 handling ([guide](docs/api_guide.md#streaming-decoder)) - **54 agent tokens** - Built-in support for chat, CoT reasoning, ReAct agents, tool calling, RAG citations ([docs](docs/special_tokens.md)) -- **Battle-tested algorithms** - Regexr with JIT (pure Rust), Aho-Corasick for 
special tokens, linked-list BPE, SentencePiece unigram +- **Battle-tested algorithms** - Regexr with JIT (pure Rust), Aho-Corasick for special tokens, linked-list BPE, SentencePiece unigram, WordPiece for BERT-family models **Cross-platform:** @@ -280,6 +280,7 @@ Splintr implements several optimizations that make tokenization faster: - **Rayon parallelism**: Leverages multiple CPU cores for batch encoding - **Linked-list BPE algorithm**: Avoids O(N²) complexity on pathological inputs - **SentencePiece unigram**: Greedy longest-match with score-based tie-breaking for Mistral/Llama-style models +- **WordPiece tokenizer**: BERT-compatible subword tokenization with `##` continuation prefix, BasicTokenizer preprocessing (lowercase, accent stripping, punctuation splitting) - **FxHashMap**: Faster lookups than default SipHash for non-adversarial contexts - **Aho-Corasick for special tokens**: Fast multi-pattern matching without regex alternation - **LRU cache**: Avoids redundant BPE encoding of frequently seen chunks @@ -370,7 +371,7 @@ If you use Splintr in your research, please cite: ```bibtex @software{splintr, author = {Farhan Syah}, - title = {Splintr: High-Performance Tokenizer (BPE + SentencePiece)}, + title = {Splintr: High-Performance Tokenizer (BPE + SentencePiece + WordPiece)}, year = {2025}, url = {https://github.com/ml-rust/splintr} } diff --git a/docs/api_guide.md b/docs/api_guide.md index 8754477..168cc38 100644 --- a/docs/api_guide.md +++ b/docs/api_guide.md @@ -14,8 +14,10 @@ This guide provides comprehensive documentation for using Splintr's Python and R - [Regular Streaming Decoder](#regular-streaming-decoder) - [ByteLevel Streaming Decoder](#bytelevel-streaming-decoder) - [Rust API Reference](#rust-api-reference) + - [Tokenize Trait](#tokenize-trait) - [BPE Tokenizer](#bpe-tokenizer) - [SentencePiece Tokenizer](#sentencepiece-tokenizer) + - [WordPiece Tokenizer](#wordpiece-tokenizer) - [Detailed Usage Examples](#detailed-usage-examples) - [Basic 
Encoding and Decoding](#basic-encoding-and-decoding) - [Batch Processing](#batch-processing) @@ -344,6 +346,26 @@ Add Splintr to your `Cargo.toml`: splintr = "*" # or pin to a specific version ``` +### Tokenize Trait + +All tokenizer backends implement the `Tokenize` trait, enabling generic code: + +```rust +use splintr::Tokenize; + +fn count_tokens(tokenizer: &dyn Tokenize, text: &str) -> usize { + tokenizer.encode(text).len() +} +``` + +**Methods:** + +- `encode(&self, text: &str) -> Vec<u32>`: Encode text to token IDs +- `decode(&self, ids: &[u32]) -> Result`: Decode token IDs to text +- `vocab_size(&self) -> usize`: Vocabulary size + +Implemented by `Tokenizer` (BPE), `SentencePieceTokenizer` (unigram), and `WordPieceTokenizer` (WordPiece). + ### BPE Tokenizer ```rust @@ -412,6 +434,43 @@ let text = tokenizer.decode_lossy(&ids); - `eos_token_id(&self) -> u32`: Get EOS token ID - `bos_token_id(&self) -> Option<u32>`: Get BOS token ID +### WordPiece Tokenizer + +For BERT-family models using WordPiece subword tokenization: + +```rust +use splintr::{WordPieceTokenizer, Tokenize}; + +// Create from a flat vocabulary (index = token ID) +let vocab = vec![ + "[PAD]", "[UNK]", "[CLS]", "[SEP]", + "hello", "world", "##ing", "##s", +].into_iter().map(String::from).collect(); + +let tokenizer = WordPieceTokenizer::new( + vocab, // Vec<String> — token strings indexed by ID + 1, // UNK token ID + 200, // Max word length before mapping to UNK + true, // Lowercase and strip accents (for uncased models) +); + +// Encode (BasicTokenizer + WordPiece greedy longest-match) +let ids = tokenizer.encode("Hello world"); + +// Decode (reconstructs text, skips [CLS]/[SEP]/[PAD] special tokens) +let text = tokenizer.decode(&ids)?; +``` + +#### Methods + +- `encode(&self, text: &str) -> Vec<u32>`: BasicTokenizer + WordPiece subword tokenization +- `decode(&self, ids: &[u32]) -> Result`: Decode, joining subwords and removing `##` prefixes +- `vocab_size(&self) -> usize`: Vocabulary size +- `cls_token_id(&self) ->
Option<u32>`: `[CLS]` token ID +- `sep_token_id(&self) -> Option<u32>`: `[SEP]` token ID +- `pad_token_id(&self) -> Option<u32>`: `[PAD]` token ID +- `unk_token_id(&self) -> u32`: `[UNK]` token ID + ### Error Handling The Rust API uses `Result` types for operations that can fail: diff --git a/docs/special_tokens.md b/docs/special_tokens.md index 46514eb..28ea455 100644 --- a/docs/special_tokens.md +++ b/docs/special_tokens.md @@ -82,15 +82,15 @@ This convention mirrors XML/HTML for familiarity while using `<|...|>` to avoid Token IDs are carefully allocated to avoid conflicts with reserved ranges: -| Model | Regular Tokens | Reserved Range | Agent Tokens | Total | -| ------------- | -------------- | --------------- | ----------------- | --------- | -| `cl100k_base` | 0-100,255 | 100,257-100,276 | 100,277-100,330 | 100,331 | -| `o200k_base` | 0-199,997 | 199,999-200,018 | 200,019-200,072 | 200,073 | -| `llama3` | 0-127,999 | 128,000-128,261 | 128,300-128,353 | 128,354 | -| `deepseek_v3` | 0-127,999 | 128,798-128,814 | 128,900-128,953 | 128,954 | -| `mistral_v1` | 0-31,999 | 0-2 | 32,000-32,053 | 32,054 | -| `mistral_v2` | 0-32,767 | 0-9 | 32,768-32,821 | 32,822 | -| `mistral_v3` | 0-131,071 | 0-9 | 131,072-131,125 | 131,126 | +| Model | Regular Tokens | Reserved Range | Agent Tokens | Total | +| ------------- | -------------- | --------------- | --------------- | ------- | +| `cl100k_base` | 0-100,255 | 100,257-100,276 | 100,277-100,330 | 100,331 | +| `o200k_base` | 0-199,997 | 199,999-200,018 | 200,019-200,072 | 200,073 | +| `llama3` | 0-127,999 | 128,000-128,261 | 128,300-128,353 | 128,354 | +| `deepseek_v3` | 0-127,999 | 128,798-128,814 | 128,900-128,953 | 128,954 | +| `mistral_v1` | 0-31,999 | 0-2 | 32,000-32,053 | 32,054 | +| `mistral_v2` | 0-32,767 | 0-9 | 32,768-32,821 | 32,822 | +| `mistral_v3` | 0-131,071 | 0-9 | 131,072-131,125 | 131,126 | ### Why These Ranges?
@@ -192,34 +192,34 @@ Splintr's `deepseek_v3` vocabulary includes the base 128,000 BPE tokens plus all #### Core Tokens -| Token | ID | Purpose | -| ------------------ | ------ | ------------------------- | -| `<\|begin▁of▁sentence\|>` | 0 | Beginning of sequence (BOS) | -| `<\|end▁of▁sentence\|>` | 1 | End of sequence (EOS) | +| Token | ID | Purpose | +| ------------------------- | --- | --------------------------- | +| `<\|begin▁of▁sentence\|>` | 0 | Beginning of sequence (BOS) | +| `<\|end▁of▁sentence\|>` | 1 | End of sequence (EOS) | **Note**: DeepSeek uses the special underscore character `▁` (U+2581, "LOWER ONE EIGHTH BLOCK") in token names, which differs from the regular underscore `_`. #### Reasoning Tokens (Native DeepSeek) -| Token | ID | Purpose | -| ------------------ | ------ | ------------------------- | -| `<\|begin▁of▁thought\|>` | 128798 | Start of thinking block | -| `<\|end▁of▁thought\|>` | 128799 | End of thinking block | +| Token | ID | Purpose | +| ------------------------ | ------ | ----------------------- | +| `<\|begin▁of▁thought\|>` | 128798 | Start of thinking block | +| `<\|end▁of▁thought\|>` | 128799 | End of thinking block | #### Fill-in-the-Middle (FIM) Tokens -| Token | ID | Purpose | -| ------------------- | ------ | ------------------------------ | -| `<\|fim▁begin\|>` | 128800 | Start of FIM context | -| `<\|fim▁hole\|>` | 128801 | Placeholder for code to insert | -| `<\|fim▁end\|>` | 128802 | End of FIM context | +| Token | ID | Purpose | +| ----------------- | ------ | ------------------------------ | +| `<\|fim▁begin\|>` | 128800 | Start of FIM context | +| `<\|fim▁hole\|>` | 128801 | Placeholder for code to insert | +| `<\|fim▁end\|>` | 128802 | End of FIM context | #### Chat Tokens -| Token | ID | Purpose | -| ----------------- | ------ | ------------------- | -| `<\|User\|>` | 128803 | User turn marker | -| `<\|Assistant\|>` | 128804 | Assistant turn marker | +| Token | ID | Purpose | +| ------------------- | ------ | 
------------------------ | +| `<\|User\|>` | 128803 | User turn marker | +| `<\|Assistant\|>` | 128804 | Assistant turn marker | | `<\|end▁of▁turn\|>` | 128805 | End of conversation turn | #### Tool/Function Calling Tokens @@ -284,11 +284,11 @@ Splintr's `mistral_v1` vocabulary includes ~32,000 BPE tokens plus 54 agent toke #### Core SentencePiece Tokens -| Token | ID | Purpose | -| ------- | -- | ---------------------------- | -| `<unk>` | 0 | Unknown token | -| `<s>` | 1 | Beginning of sequence (BOS) | -| `</s>` | 2 | End of sequence (EOS) | +| Token | ID | Purpose | +| ------- | --- | --------------------------- | +| `<unk>` | 0 | Unknown token | +| `<s>` | 1 | Beginning of sequence (BOS) | +| `</s>` | 2 | End of sequence (EOS) | ### mistral_v2 (Mistral 7B v0.3, Codestral, Mixtral 8x22B) @@ -296,23 +296,23 @@ V2 extends V1 with 768 control tokens for tool calling and instruction formattin #### Core SentencePiece Tokens (same as V1) -| Token | ID | Purpose | -| ------- | -- | ---------------------------- | -| `<unk>` | 0 | Unknown token | -| `<s>` | 1 | Beginning of sequence (BOS) | -| `</s>` | 2 | End of sequence (EOS) | +| Token | ID | Purpose | +| ------- | --- | --------------------------- | +| `<unk>` | 0 | Unknown token | +| `<s>` | 1 | Beginning of sequence (BOS) | +| `</s>` | 2 | End of sequence (EOS) | #### V2 Control Tokens -| Token | ID | Purpose | -| -------------------- | -- | -------------------------- | -| `[INST]` | 3 | Start of user instruction | -| `[/INST]` | 4 | End of user instruction | -| `[TOOL_CALLS]` | 5 | Tool calling block | -| `[AVAILABLE_TOOLS]` | 6 | Available tools definition | -| `[/AVAILABLE_TOOLS]` | 7 | End of tools definition | -| `[TOOL_RESULTS]` | 8 | Tool results block | -| `[/TOOL_RESULTS]` | 9 | End of tool results | +| Token | ID | Purpose | +| -------------------- | --- | -------------------------- | +| `[INST]` | 3 | Start of user instruction | +| `[/INST]` | 4 | End of user instruction | +| `[TOOL_CALLS]` | 5 | Tool calling block | +| `[AVAILABLE_TOOLS]` | 6 |
Available tools definition | +| `[/AVAILABLE_TOOLS]` | 7 | End of tools definition | +| `[TOOL_RESULTS]` | 8 | Tool results block | +| `[/TOOL_RESULTS]` | 9 | End of tool results | ### mistral_v3 (Mistral NeMo, Large 2, Pixtral) @@ -320,11 +320,11 @@ V3 uses a completely different tokenizer architecture: **Tekken** (Tiktoken-base #### Core Tokens -| Token | ID | Purpose | -| ------- | -- | ---------------------------- | -| `<unk>` | 0 | Unknown token | -| `<s>` | 1 | Beginning of sequence (BOS) | -| `</s>` | 2 | End of sequence (EOS) | +| Token | ID | Purpose | +| ------- | --- | --------------------------- | +| `<unk>` | 0 | Unknown token | +| `<s>` | 1 | Beginning of sequence (BOS) | +| `</s>` | 2 | End of sequence (EOS) | V3 includes the same control tokens as V2 (`[INST]`, `[/INST]`, etc.) but uses Tiktoken encoding instead of SentencePiece. @@ -366,6 +366,7 @@ V3 does NOT use SentencePiece - it uses Tiktoken (similar to o200k_base). ## Agent Token Categories > **Mistral Agent Token IDs**: The tables below show `mistral_v1 ID`. For V2 and V3: +> > - **V1**: Agent tokens start at 32,000 > - **V2**: Agent tokens start at 32,768 (add 768 to V1 IDs) > - **V3**: Agent tokens start at 131,072 @@ -402,9 +403,9 @@ The capital of France is Paris.<|im_end|> **Purpose**: Enable System 2 (slow, deliberate) reasoning similar to DeepSeek-R1 or OpenAI o1.
| Token | cl100k ID | o200k ID | llama3 ID | deepseek_v3 ID | mistral_v1 ID | Description | -| -------------- | --------- | -------- | --------- | -------------- | ---------- | ------------------------ | -| `<\|think\|>` | 100282 | 200024 | 128305 | 128905 | 32005 | Start of reasoning block | -| `<\|/think\|>` | 100283 | 200025 | 128306 | 128906 | 32006 | End of reasoning block | +| -------------- | --------- | -------- | --------- | -------------- | ------------- | ------------------------ | +| `<\|think\|>` | 100282 | 200024 | 128305 | 128905 | 32005 | Start of reasoning block | +| `<\|/think\|>` | 100283 | 200025 | 128306 | 128906 | 32006 | End of reasoning block | **Rationale**: Chain-of-Thought (CoT) prompting significantly improves model performance on complex tasks. Dedicated thinking tokens allow: @@ -430,15 +431,15 @@ The capital of France is Paris. **Purpose**: Implement the ReAct (Reason + Act) paradigm for autonomous agents. | Token | cl100k ID | o200k ID | llama3 ID | deepseek_v3 ID | mistral_v1 ID | Description | -| ---------------- | --------- | -------- | --------- | -------------- | ------------------------------- | -| `<\|plan\|>` | 100284 | 200026 | 128307 | 128907 | 32007 | High-level strategy formulation | -| `<\|/plan\|>` | 100285 | 200027 | 128308 | 128908 | 32008 | End of plan | -| `<\|step\|>` | 100286 | 200028 | 128309 | 128909 | 32009 | Individual step within plan | -| `<\|/step\|>` | 100287 | 200029 | 128310 | 128910 | 32010 | End of step | -| `<\|act\|>` | 100288 | 200030 | 128311 | 128911 | 32011 | Action intent declaration | -| `<\|/act\|>` | 100289 | 200031 | 128312 | 128912 | 32012 | End of action | -| `<\|observe\|>` | 100290 | 200032 | 128313 | 128913 | 32013 | Environment feedback | -| `<\|/observe\|>` | 100291 | 200033 | 128314 | 128914 | 32014 | End of observation | +| ---------------- | --------- | -------- | --------- | -------------- | ------------- | ------------------------------- | +| `<\|plan\|>` | 100284 | 200026 | 
128307 | 128907 | 32007 | High-level strategy formulation | +| `<\|/plan\|>` | 100285 | 200027 | 128308 | 128908 | 32008 | End of plan | +| `<\|step\|>` | 100286 | 200028 | 128309 | 128909 | 32009 | Individual step within plan | +| `<\|/step\|>` | 100287 | 200029 | 128310 | 128910 | 32010 | End of step | +| `<\|act\|>` | 100288 | 200030 | 128311 | 128911 | 32011 | Action intent declaration | +| `<\|/act\|>` | 100289 | 200031 | 128312 | 128912 | 32012 | End of action | +| `<\|observe\|>` | 100290 | 200032 | 128313 | 128913 | 32013 | Environment feedback | +| `<\|/observe\|>` | 100291 | 200033 | 128314 | 128914 | 32014 | End of observation | **Rationale**: The [ReAct paper](https://arxiv.org/abs/2210.03629) demonstrated that interleaving reasoning and acting improves agent performance. These tokens create a structured loop: @@ -471,13 +472,13 @@ The current temperature in London is 18°C with partly cloudy skies. **Purpose**: Structured tool use with explicit success/error handling. | Token | cl100k ID | o200k ID | llama3 ID | deepseek_v3 ID | mistral_v1 ID | Description | -| ----------------- | --------- | -------- | --------- | -------------- | --------------------------- | -| `<\|function\|>` | 100292 | 200034 | 128315 | 128915 | 32015 | Function call specification | -| `<\|/function\|>` | 100293 | 200035 | 128316 | 128916 | 32016 | End of function call | -| `<\|result\|>` | 100294 | 200036 | 128317 | 128917 | 32017 | Successful return value | -| `<\|/result\|>` | 100295 | 200037 | 128318 | 128918 | 32018 | End of result | -| `<\|error\|>` | 100296 | 200038 | 128319 | 128919 | 32019 | Execution error | -| `<\|/error\|>` | 100297 | 200039 | 128320 | 128920 | 32020 | End of error | +| ----------------- | --------- | -------- | --------- | -------------- | ------------- | --------------------------- | +| `<\|function\|>` | 100292 | 200034 | 128315 | 128915 | 32015 | Function call specification | +| `<\|/function\|>` | 100293 | 200035 | 128316 | 128916 | 32016 | End of 
function call | +| `<\|result\|>` | 100294 | 200036 | 128317 | 128917 | 32017 | Successful return value | +| `<\|/result\|>` | 100295 | 200037 | 128318 | 128918 | 32018 | End of result | +| `<\|error\|>` | 100296 | 200038 | 128319 | 128919 | 32019 | Execution error | +| `<\|/error\|>` | 100297 | 200039 | 128320 | 128920 | 32020 | End of error | **Rationale**: Function calling is fundamental to agent capabilities. Separating `<|act|>` (intent) from `<|function|>` (technical payload) allows: @@ -507,13 +508,13 @@ The `<|error|>` token is critical for robust agents—it signals that the previo **Purpose**: Jupyter notebook-style code interpreter flow. | Token | cl100k ID | o200k ID | llama3 ID | deepseek_v3 ID | mistral_v1 ID | Description | -| --------------- | --------- | -------- | --------- | -------------- | --------------------- | -| `<\|code\|>` | 100298 | 200040 | 128321 | 128921 | 32021 | Code block to execute | -| `<\|/code\|>` | 100299 | 200041 | 128322 | 128922 | 32022 | End of code block | -| `<\|output\|>` | 100300 | 200042 | 128323 | 128923 | 32023 | Execution output | -| `<\|/output\|>` | 100301 | 200043 | 128324 | 128924 | 32024 | End of output | -| `<\|lang\|>` | 100302 | 200044 | 128325 | 128925 | 32025 | Language identifier | -| `<\|/lang\|>` | 100303 | 200045 | 128326 | 128926 | 32026 | End of language tag | +| --------------- | --------- | -------- | --------- | -------------- | ------------- | --------------------- | +| `<\|code\|>` | 100298 | 200040 | 128321 | 128921 | 32021 | Code block to execute | +| `<\|/code\|>` | 100299 | 200041 | 128322 | 128922 | 32022 | End of code block | +| `<\|output\|>` | 100300 | 200042 | 128323 | 128923 | 32023 | Execution output | +| `<\|/output\|>` | 100301 | 200043 | 128324 | 128924 | 32024 | End of output | +| `<\|lang\|>` | 100302 | 200044 | 128325 | 128925 | 32025 | Language identifier | +| `<\|/lang\|>` | 100303 | 200045 | 128326 | 128926 | 32026 | End of language tag | **Rationale**: Code execution is a 
powerful agent capability. These tokens model the notebook paradigm: @@ -543,15 +544,15 @@ print(f"Area: {area:.2f}") **Purpose**: Retrieval-Augmented Generation with source attribution. | Token | cl100k ID | o200k ID | llama3 ID | deepseek_v3 ID | mistral_v1 ID | Description | -| ---------------- | --------- | -------- | --------- | -------------- | ----------------------- | -| `<\|context\|>` | 100304 | 200046 | 128327 | 128927 | 32027 | Retrieved context block | -| `<\|/context\|>` | 100305 | 200047 | 128328 | 128928 | 32028 | End of context | -| `<\|quote\|>` | 100306 | 200048 | 128329 | 128929 | 32029 | Direct quotation | -| `<\|/quote\|>` | 100307 | 200049 | 128330 | 128930 | 32030 | End of quote | -| `<\|cite\|>` | 100308 | 200050 | 128331 | 128931 | 32031 | Citation reference | -| `<\|/cite\|>` | 100309 | 200051 | 128332 | 128932 | 32032 | End of citation | -| `<\|source\|>` | 100310 | 200052 | 128333 | 128933 | 32033 | Source metadata | -| `<\|/source\|>` | 100311 | 200053 | 128334 | 128934 | 32034 | End of source | +| ---------------- | --------- | -------- | --------- | -------------- | ------------- | ----------------------- | +| `<\|context\|>` | 100304 | 200046 | 128327 | 128927 | 32027 | Retrieved context block | +| `<\|/context\|>` | 100305 | 200047 | 128328 | 128928 | 32028 | End of context | +| `<\|quote\|>` | 100306 | 200048 | 128329 | 128929 | 32029 | Direct quotation | +| `<\|/quote\|>` | 100307 | 200049 | 128330 | 128930 | 32030 | End of quote | +| `<\|cite\|>` | 100308 | 200050 | 128331 | 128931 | 32031 | Citation reference | +| `<\|/cite\|>` | 100309 | 200051 | 128332 | 128932 | 32032 | End of citation | +| `<\|source\|>` | 100310 | 200052 | 128333 | 128933 | 32033 | Source metadata | +| `<\|/source\|>` | 100311 | 200053 | 128334 | 128934 | 32034 | End of source | **Rationale**: RAG systems retrieve relevant documents to ground model responses. 
These tokens enable: @@ -582,11 +583,11 @@ population of approximately <|quote|>2,102,650 residents<|/quote|> **Purpose**: Long-term memory and state persistence across sessions. | Token | cl100k ID | o200k ID | llama3 ID | deepseek_v3 ID | mistral_v1 ID | Description | -| --------------- | --------- | -------- | --------- | -------------- | ------------------- | -| `<\|memory\|>` | 100312 | 200054 | 128335 | 128935 | 32035 | Store information | -| `<\|/memory\|>` | 100313 | 200055 | 128336 | 128936 | 32036 | End of memory block | -| `<\|recall\|>` | 100314 | 200056 | 128337 | 128937 | 32037 | Retrieved memory | -| `<\|/recall\|>` | 100315 | 200057 | 128338 | 128938 | 32038 | End of recall | +| --------------- | --------- | -------- | --------- | -------------- | ------------- | ------------------- | +| `<\|memory\|>` | 100312 | 200054 | 128335 | 128935 | 32035 | Store information | +| `<\|/memory\|>` | 100313 | 200055 | 128336 | 128936 | 32036 | End of memory block | +| `<\|recall\|>` | 100314 | 200056 | 128337 | 128937 | 32037 | Retrieved memory | +| `<\|/recall\|>` | 100315 | 200057 | 128338 | 128938 | 32038 | End of recall | **Rationale**: Persistent memory enables agents to: @@ -614,10 +615,10 @@ Hello Alice! Here's a brief answer: The capital of France is Paris. **Purpose**: Sequence control and formatting. 
| Token | cl100k ID | o200k ID | llama3 ID | deepseek_v3 ID | mistral_v1 ID | Description | -| ------------ | --------- | -------- | --------- | -------------- | --------------------------- | -| `<\|pad\|>` | 100316 | 200058 | 128339 | 128939 | 32039 | Padding for batch alignment | -| `<\|stop\|>` | 100317 | 200059 | 128340 | 128940 | 32040 | Generation stop signal | -| `<\|sep\|>` | 100318 | 200060 | 128341 | 128941 | 32041 | Segment separator | +| ------------ | --------- | -------- | --------- | -------------- | ------------- | --------------------------- | +| `<\|pad\|>` | 100316 | 200058 | 128339 | 128939 | 32039 | Padding for batch alignment | +| `<\|stop\|>` | 100317 | 200059 | 128340 | 128940 | 32040 | Generation stop signal | +| `<\|sep\|>` | 100318 | 200060 | 128341 | 128941 | 32041 | Segment separator | **Rationale**: These are utility tokens for training and inference: @@ -632,13 +633,13 @@ Hello Alice! Here's a brief answer: The capital of France is Paris. **Purpose**: Placeholders for non-text content. 
| Token | cl100k ID | o200k ID | llama3 ID | deepseek_v3 ID | mistral_v1 ID | Description | -| -------------- | --------- | -------- | --------- | -------------- | ------------- | -| `<\|image\|>` | 100319 | 200061 | 128256\* | 128942 | 32042 | Image content | -| `<\|/image\|>` | 100320 | 200062 | 128257 | 128943 | 32043 | End of image | -| `<\|audio\|>` | 100321 | 200063 | 128258 | 128944 | 32044 | Audio content | -| `<\|/audio\|>` | 100322 | 200064 | 128259 | 128945 | 32045 | End of audio | -| `<\|video\|>` | 100323 | 200065 | 128260 | 128946 | 32046 | Video content | -| `<\|/video\|>` | 100324 | 200066 | 128261 | 128947 | 32047 | End of video | +| -------------- | --------- | -------- | --------- | -------------- | ------------- | ------------- | +| `<\|image\|>` | 100319 | 200061 | 128256\* | 128942 | 32042 | Image content | +| `<\|/image\|>` | 100320 | 200062 | 128257 | 128943 | 32043 | End of image | +| `<\|audio\|>` | 100321 | 200063 | 128258 | 128944 | 32044 | Audio content | +| `<\|/audio\|>` | 100322 | 200064 | 128259 | 128945 | 32045 | End of audio | +| `<\|video\|>` | 100323 | 200065 | 128260 | 128946 | 32046 | Video content | +| `<\|/video\|>` | 100324 | 200066 | 128261 | 128947 | 32047 | End of video | \*Note: Llama 3's `<|image|>` token (128256) is aligned with the official Meta Llama 3.2-Vision token ID for compatibility. @@ -664,13 +665,13 @@ The image shows a sunset over the ocean with vibrant orange and purple colors. **Purpose**: Semantic layout for parsing structured documents. 
| Token | cl100k ID | o200k ID | llama3 ID | deepseek_v3 ID | mistral_v1 ID | Description | -| ---------------- | --------- | -------- | --------- | -------------- | ---------------------- | -| `<\|title\|>` | 100325 | 200067 | 128348 | 128948 | 32048 | Document/section title | -| `<\|/title\|>` | 100326 | 200068 | 128349 | 128949 | 32049 | End of title | -| `<\|section\|>` | 100327 | 200069 | 128350 | 128950 | 32050 | Semantic section | -| `<\|/section\|>` | 100328 | 200070 | 128351 | 128951 | 32051 | End of section | -| `<\|summary\|>` | 100329 | 200071 | 128352 | 128952 | 32052 | Content summary | -| `<\|/summary\|>` | 100330 | 200072 | 128353 | 128953 | 32053 | End of summary | +| ---------------- | --------- | -------- | --------- | -------------- | ------------- | ---------------------- | +| `<\|title\|>` | 100325 | 200067 | 128348 | 128948 | 32048 | Document/section title | +| `<\|/title\|>` | 100326 | 200068 | 128349 | 128949 | 32049 | End of title | +| `<\|section\|>` | 100327 | 200069 | 128350 | 128950 | 32050 | Semantic section | +| `<\|/section\|>` | 100328 | 200070 | 128351 | 128951 | 32051 | End of section | +| `<\|summary\|>` | 100329 | 200071 | 128352 | 128952 | 32052 | Content summary | +| `<\|/summary\|>` | 100330 | 200072 | 128353 | 128953 | 32053 | End of summary | **Rationale**: When processing structured documents (papers, reports, documentation), these tokens help: diff --git a/pyproject.toml b/pyproject.toml index fbece2d..0ad222e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "splintr-rs" -version = "0.9.0" +version = "0.9.1" description = "Fast Rust tokenizer (BPE + SentencePiece) with Python bindings" readme = "README.md" license = { text = "MIT" } diff --git a/python/splintr/__init__.py b/python/splintr/__init__.py index dfb3b52..9cc28ae 100644 --- a/python/splintr/__init__.py +++ b/python/splintr/__init__.py @@ -153,4 +153,4 @@ "MISTRAL_V2_AGENT_TOKENS", 
"MISTRAL_V3_AGENT_TOKENS", ] -__version__ = "0.9.0" +__version__ = "0.9.1" diff --git a/src/core/mod.rs b/src/core/mod.rs index 59dfc1a..3132cb2 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,10 +1,13 @@ //! Core tokenization engine for splintr. //! -//! This module contains the high-performance BPE tokenizer implementation with: +//! This module contains multi-backend tokenizer implementations: //! - Byte-pair encoding using a linked-list algorithm (O(N) merges vs O(N²) for vectors) +//! - SentencePiece unigram tokenizer with score-based tie-breaking +//! - WordPiece tokenizer for BERT-family models with `##` continuation prefix //! - Vocabulary loading from tiktoken format //! - UTF-8 safe streaming decoder for LLM output //! - Main tokenizer interface with LRU caching and Rayon parallelism +//! - Unified [`Tokenize`] trait for backend-agnostic usage //! //! # Architecture //! @@ -13,8 +16,8 @@ //! - [`Tokenizer`]: Main tokenizer struct with encoding/decoding API, LRU cache, //! and Aho-Corasick special token matching. Uses regexr backend by default, //! with optional PCRE2 backend via `.pcre2(true)` (requires `pcre2` feature). -//! - [`bpe`]: Low-level byte-pair encoding algorithm using linked-list approach -//! - [`vocab`]: Vocabulary loading utilities for tiktoken format +//! - `bpe`: Low-level byte-pair encoding algorithm using linked-list approach +//! - `vocab`: Vocabulary loading utilities for tiktoken format //! - [`StreamingDecoder`]: UTF-8 safe streaming decoder for token-by-token LLM output //! - [`ByteLevelStreamingDecoder`]: Streaming decoder for ByteLevel tokenizers (DeepSeek, GPT-2) //! 
@@ -32,8 +35,10 @@ pub mod byte_level; pub mod pretrained; pub mod sentencepiece; mod streaming; +pub mod tokenize; mod tokenizer; mod vocab; +pub mod wordpiece; pub use bpe::byte_pair_encode; pub use byte_level::{byte_level_decode, byte_level_decode_bytes, byte_level_encode}; @@ -45,8 +50,10 @@ pub use pretrained::{ }; pub use sentencepiece::{SentencePieceError, SentencePieceTokenizer}; pub use streaming::{ByteLevelStreamingDecoder, StreamingDecoder}; +pub use tokenize::{Tokenize, TokenizeError}; pub use tokenizer::{ cl100k_agent_tokens, o200k_agent_tokens, Tokenizer, TokenizerError, CL100K_BASE_PATTERN, LLAMA3_PATTERN, MISTRAL_V3_PATTERN, O200K_BASE_PATTERN, SENTENCEPIECE_PATTERN, }; pub use vocab::{build_decoder, load_tiktoken_bpe, load_tiktoken_bpe_file, VocabError}; +pub use wordpiece::WordPieceTokenizer; diff --git a/src/core/sentencepiece.rs b/src/core/sentencepiece.rs index f731290..df8d6a8 100644 --- a/src/core/sentencepiece.rs +++ b/src/core/sentencepiece.rs @@ -248,6 +248,21 @@ impl SentencePieceTokenizer { } } +impl super::tokenize::Tokenize for SentencePieceTokenizer { + fn encode(&self, text: &str) -> Vec { + self.encode(text) + } + + fn decode(&self, ids: &[u32]) -> Result { + self.decode(ids) + .map_err(|e| super::tokenize::TokenizeError::Other(e.to_string())) + } + + fn vocab_size(&self) -> usize { + self.vocab_size() + } +} + /// Parse a byte-fallback token like `<0x0A>` into its byte value. fn parse_byte_fallback(token: &str) -> Option { let inner = token.strip_prefix("<0x")?.strip_suffix('>')?; diff --git a/src/core/tokenize.rs b/src/core/tokenize.rs new file mode 100644 index 0000000..e554508 --- /dev/null +++ b/src/core/tokenize.rs @@ -0,0 +1,33 @@ +//! Unified tokenizer trait for all splintr backends. +//! +//! The `Tokenize` trait provides a common interface across BPE, SentencePiece, +//! and WordPiece tokenizers, enabling generic code that works with any backend. + +/// Common interface for all tokenizer backends. 
+/// +/// Implemented by [`Tokenizer`](super::Tokenizer) (BPE), +/// [`SentencePieceTokenizer`](super::SentencePieceTokenizer) (unigram), and +/// [`WordPieceTokenizer`](super::WordPieceTokenizer) (WordPiece). +pub trait Tokenize: Send + Sync { + /// Encode text into token IDs. + fn encode(&self, text: &str) -> Vec; + + /// Decode token IDs back to text. + /// + /// Returns an error if any token ID is invalid. + fn decode(&self, ids: &[u32]) -> Result; + + /// Return the vocabulary size (number of distinct tokens). + fn vocab_size(&self) -> usize; +} + +/// Error type for the [`Tokenize`] trait's decode method. +#[derive(Debug, thiserror::Error)] +pub enum TokenizeError { + #[error("Decoding error: invalid UTF-8")] + Utf8Error, + #[error("Decoding error: token ID {0} out of range")] + InvalidTokenId(u32), + #[error("{0}")] + Other(String), +} diff --git a/src/core/tokenizer.rs b/src/core/tokenizer.rs index 9dac75b..0a0e343 100644 --- a/src/core/tokenizer.rs +++ b/src/core/tokenizer.rs @@ -264,15 +264,15 @@ impl RegexBackend { /// /// This tokenizer is optimized for high throughput across different workloads: /// -/// - **Single text encoding**: Uses sequential processing via [`encode`]. +/// - **Single text encoding**: Uses sequential processing via [`Tokenizer::encode`]. /// Benchmarks show sequential is faster for texts up to ~1MB due to Rayon /// thread pool overhead. Sequential achieves ~50 MB/s consistently. /// -/// - **Batch encoding**: Uses Rayon parallelism via [`encode_batch`]. +/// - **Batch encoding**: Uses Rayon parallelism via [`Tokenizer::encode_batch`]. /// Parallelizes across texts (not within a single text), achieving ~110 MB/s /// on batch workloads - approximately 10-12x faster than tiktoken. /// -/// - **Very large single texts (>1MB)**: Use [`encode_rayon`] for texts larger +/// - **Very large single texts (>1MB)**: Use [`Tokenizer::encode_rayon`] for texts larger /// than ~1MB where Rayon parallelization within the text becomes beneficial. 
/// /// # Regex Backend @@ -1106,6 +1106,23 @@ impl Clone for Tokenizer { } } +impl super::tokenize::Tokenize for Tokenizer { + fn encode(&self, text: &str) -> Vec { + self.encode(text) + } + + fn decode(&self, ids: &[u32]) -> Result { + self.decode(ids).map_err(|e| match e { + TokenizerError::Utf8Error => super::tokenize::TokenizeError::Utf8Error, + other => super::tokenize::TokenizeError::Other(other.to_string()), + }) + } + + fn vocab_size(&self) -> usize { + self.vocab_size() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/core/wordpiece.rs b/src/core/wordpiece.rs new file mode 100644 index 0000000..c203ee0 --- /dev/null +++ b/src/core/wordpiece.rs @@ -0,0 +1,418 @@ +//! WordPiece tokenizer for BERT-family models. +//! +//! Implements the standard BERT tokenization pipeline: +//! 1. **BasicTokenizer**: lowercase, strip accents, split on whitespace and punctuation +//! 2. **WordPiece**: greedy longest-match subword tokenization with `##` continuation prefix +//! +//! Handles `[CLS]`, `[SEP]`, `[PAD]`, `[UNK]` special tokens. + +use super::tokenize::{Tokenize, TokenizeError}; +use std::collections::HashMap; + +/// WordPiece tokenizer compatible with BERT-family models. +/// +/// Constructed from a flat vocabulary list where index = token ID +/// (same format as GGUF `tokenizer.ggml.tokens`). 
+/// +/// # Example +/// +/// ``` +/// use splintr::{WordPieceTokenizer, Tokenize}; +/// +/// let vocab = vec![ +/// "[PAD]", "[UNK]", "[CLS]", "[SEP]", +/// "hello", "world", "##ing", "##s", +/// ].into_iter().map(String::from).collect(); +/// let tok = WordPieceTokenizer::new(vocab, 1, 200, true); +/// let ids = tok.encode("hello world"); +/// ``` +pub struct WordPieceTokenizer { + /// Token string → ID + token_to_id: HashMap, + /// ID → token string + id_to_token: Vec, + /// Token ID for unknown tokens + unk_token_id: u32, + /// Maximum characters in a single word before it's treated as [UNK] + max_word_len: usize, + /// Whether to lowercase and strip accents (for uncased models) + do_lower_case: bool, + /// Whether the vocabulary uses `##` prefix for continuation tokens. + /// If false, continuations are looked up without prefix (GGUF-stripped vocabs). + has_continuation_prefix: bool, + /// Special token IDs for [CLS], [SEP], [PAD] + cls_token_id: Option, + sep_token_id: Option, + pad_token_id: Option, +} + +impl WordPieceTokenizer { + /// Create a WordPiece tokenizer from a flat vocabulary. 
+ /// + /// # Arguments + /// * `vocab` - Token strings indexed by token ID + /// * `unk_token_id` - ID to use for unknown tokens + /// * `max_word_len` - Words longer than this are mapped to `[UNK]` + /// * `do_lower_case` - Whether to lowercase and strip accents (for uncased models) + pub fn new( + vocab: Vec, + unk_token_id: u32, + max_word_len: usize, + do_lower_case: bool, + ) -> Self { + let mut token_to_id = HashMap::with_capacity(vocab.len()); + for (id, token) in vocab.iter().enumerate() { + token_to_id.insert(token.clone(), id as u32); + } + + // Auto-detect whether vocab uses ## prefix for continuations + let has_continuation_prefix = token_to_id.keys().any(|k| k.starts_with("##")); + + let cls_token_id = token_to_id.get("[CLS]").copied(); + let sep_token_id = token_to_id.get("[SEP]").copied(); + let pad_token_id = token_to_id.get("[PAD]").copied(); + + Self { + token_to_id, + id_to_token: vocab, + unk_token_id, + max_word_len, + do_lower_case, + has_continuation_prefix, + cls_token_id, + sep_token_id, + pad_token_id, + } + } + + /// Get the `[CLS]` token ID, if present in the vocabulary. + pub fn cls_token_id(&self) -> Option { + self.cls_token_id + } + + /// Get the `[SEP]` token ID, if present in the vocabulary. + pub fn sep_token_id(&self) -> Option { + self.sep_token_id + } + + /// Get the `[PAD]` token ID, if present in the vocabulary. + pub fn pad_token_id(&self) -> Option { + self.pad_token_id + } + + /// Get the `[UNK]` token ID. + pub fn unk_token_id(&self) -> u32 { + self.unk_token_id + } + + /// Pre-tokenize: lowercase, strip accents, split on whitespace and punctuation. 
+ fn basic_tokenize(&self, text: &str) -> Vec { + let text = if self.do_lower_case { + let lowered = text.to_lowercase(); + strip_accents(&lowered) + } else { + text.to_string() + }; + + // Split on whitespace, then split each token on punctuation boundaries + let mut tokens = Vec::new(); + for word in text.split_whitespace() { + split_on_punctuation(word, &mut tokens); + } + tokens + } + + /// WordPiece: greedily match longest subword. + /// + /// If the vocabulary uses `##` prefix (standard HuggingFace format), + /// continuations are looked up with `##` prefix. Otherwise (GGUF-stripped + /// vocabs), continuations are looked up directly. + fn wordpiece_tokenize(&self, word: &str) -> Vec { + let chars: Vec = word.chars().collect(); + if chars.len() > self.max_word_len { + return vec![self.unk_token_id]; + } + + let mut ids = Vec::new(); + let mut start = 0; + + while start < chars.len() { + let mut end = chars.len(); + let mut found = false; + + while start < end { + let raw: String = chars[start..end].iter().collect(); + let lookup = if start == 0 || !self.has_continuation_prefix { + raw + } else { + format!("##{}", raw) + }; + + if let Some(&id) = self.token_to_id.get(&lookup) { + ids.push(id); + found = true; + start = end; + break; + } + + end -= 1; + } + + if !found { + ids.push(self.unk_token_id); + start += 1; + } + } + + ids + } +} + +impl Tokenize for WordPieceTokenizer { + fn encode(&self, text: &str) -> Vec { + let words = self.basic_tokenize(text); + let mut ids = Vec::new(); + + for word in &words { + let word_ids = self.wordpiece_tokenize(word); + ids.extend(word_ids); + } + + ids + } + + fn decode(&self, ids: &[u32]) -> Result { + if self.has_continuation_prefix { + self.decode_with_prefix(ids) + } else { + self.decode_without_prefix(ids) + } + } + + fn vocab_size(&self) -> usize { + self.id_to_token.len() + } +} + +impl WordPieceTokenizer { + /// Decode when vocab uses `##` prefix — use prefix presence to detect continuations. 
+ fn decode_with_prefix(&self, ids: &[u32]) -> Result { + let mut pieces = Vec::with_capacity(ids.len()); + + for &id in ids { + let token = self + .id_to_token + .get(id as usize) + .ok_or(TokenizeError::InvalidTokenId(id))?; + + if is_special_token(token) { + continue; + } + + if let Some(stripped) = token.strip_prefix("##") { + pieces.push(stripped.to_string()); + } else { + if !pieces.is_empty() { + pieces.push(" ".to_string()); + } + pieces.push(token.to_string()); + } + } + + Ok(pieces.join("")) + } + + /// Decode when vocab has no `##` prefix (GGUF-stripped). + /// Without `##`, we can't distinguish continuations from word starts, + /// so we just join with spaces between each token. + fn decode_without_prefix(&self, ids: &[u32]) -> Result { + let mut parts = Vec::with_capacity(ids.len()); + + for &id in ids { + let token = self + .id_to_token + .get(id as usize) + .ok_or(TokenizeError::InvalidTokenId(id))?; + + if is_special_token(token) { + continue; + } + + parts.push(token.as_str()); + } + + Ok(parts.join(" ")) + } +} + +fn is_special_token(token: &str) -> bool { + matches!(token, "[CLS]" | "[SEP]" | "[PAD]" | "[UNK]" | "[MASK]") + || (token.starts_with("[unused") && token.ends_with(']')) +} + +/// Strip Unicode combining marks (accents) from text. +fn strip_accents(text: &str) -> String { + use unicode_normalization::UnicodeNormalization; + text.nfd() + .filter(|c| !unicode_normalization::char::is_combining_mark(*c)) + .collect() +} + +/// Split a word on punctuation boundaries, pushing results into `out`. +fn split_on_punctuation(word: &str, out: &mut Vec) { + let mut current = String::new(); + for c in word.chars() { + if is_punctuation(c) { + if !current.is_empty() { + out.push(std::mem::take(&mut current)); + } + out.push(c.to_string()); + } else { + current.push(c); + } + } + if !current.is_empty() { + out.push(current); + } +} + +/// Check if a character is punctuation (matching BERT's definition). 
+fn is_punctuation(c: char) -> bool { + // ASCII punctuation ranges + matches!(c, '\x21'..='\x2F' | '\x3A'..='\x40' | '\x5B'..='\x60' | '\x7B'..='\x7E') + || c.is_ascii_punctuation() + || { + // Unicode punctuation categories + let cat = unicode_general_category::get_general_category(c); + matches!( + cat, + unicode_general_category::GeneralCategory::ConnectorPunctuation + | unicode_general_category::GeneralCategory::DashPunctuation + | unicode_general_category::GeneralCategory::ClosePunctuation + | unicode_general_category::GeneralCategory::FinalPunctuation + | unicode_general_category::GeneralCategory::InitialPunctuation + | unicode_general_category::GeneralCategory::OtherPunctuation + | unicode_general_category::GeneralCategory::OpenPunctuation + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_tokenizer() -> WordPieceTokenizer { + let vocab = vec![ + "[PAD]".to_string(), // 0 + "[UNK]".to_string(), // 1 + "[CLS]".to_string(), // 2 + "[SEP]".to_string(), // 3 + "hello".to_string(), // 4 + "world".to_string(), // 5 + "##ing".to_string(), // 6 + "##s".to_string(), // 7 + "un".to_string(), // 8 + "##know".to_string(), // 9 + "##n".to_string(), // 10 + ",".to_string(), // 11 + "the".to_string(), // 12 + "a".to_string(), // 13 + ]; + WordPieceTokenizer::new(vocab, 1, 200, true) + } + + #[test] + fn test_encode_basic() { + let tok = make_tokenizer(); + let ids = tok.encode("hello world"); + assert_eq!(ids, vec![4, 5]); + } + + #[test] + fn test_encode_subwords() { + let tok = make_tokenizer(); + let ids = tok.encode("unknown"); + // "unknown" → "un" + "##know" + "##n" + assert_eq!(ids, vec![8, 9, 10]); + } + + #[test] + fn test_encode_punctuation() { + let tok = make_tokenizer(); + let ids = tok.encode("hello, world"); + // "hello" "," "world" + assert_eq!(ids, vec![4, 11, 5]); + } + + #[test] + fn test_decode_basic() { + let tok = make_tokenizer(); + let text = tok.decode(&[4, 5]).unwrap(); + assert_eq!(text, "hello world"); + } + + #[test] + fn 
test_decode_subwords() { + let tok = make_tokenizer(); + let text = tok.decode(&[8, 9, 10]).unwrap(); + assert_eq!(text, "unknown"); + } + + #[test] + fn test_decode_skips_special() { + let tok = make_tokenizer(); + let text = tok.decode(&[2, 4, 5, 3]).unwrap(); + assert_eq!(text, "hello world"); + } + + #[test] + fn test_vocab_size() { + let tok = make_tokenizer(); + assert_eq!(tok.vocab_size(), 14); + } + + #[test] + fn test_special_token_ids() { + let tok = make_tokenizer(); + assert_eq!(tok.cls_token_id(), Some(2)); + assert_eq!(tok.sep_token_id(), Some(3)); + assert_eq!(tok.pad_token_id(), Some(0)); + assert_eq!(tok.unk_token_id(), 1); + } + + #[test] + fn test_unknown_word() { + let tok = make_tokenizer(); + // "xyz" has no vocab entries → each char becomes [UNK] + let ids = tok.encode("xyz"); + assert!(ids.iter().all(|&id| id == 1)); + } + + #[test] + fn test_lowercase() { + let tok = make_tokenizer(); + let ids = tok.encode("Hello WORLD"); + assert_eq!(ids, vec![4, 5]); + } + + #[test] + fn test_case_sensitive() { + let vocab = vec![ + "[UNK]".to_string(), // 0 + "Hello".to_string(), // 1 + "hello".to_string(), // 2 + ]; + let tok = WordPieceTokenizer::new(vocab, 0, 200, false); + let ids = tok.encode("Hello"); + assert_eq!(ids, vec![1]); + let ids = tok.encode("hello"); + assert_eq!(ids, vec![2]); + } + + #[test] + fn test_decode_invalid_id() { + let tok = make_tokenizer(); + let result = tok.decode(&[999]); + assert!(result.is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 60026e9..78ab436 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,8 +4,8 @@ mod python; pub use core::{ ByteLevelStreamingDecoder, SentencePieceError, SentencePieceTokenizer, StreamingDecoder, - Tokenizer, TokenizerError, CL100K_BASE_PATTERN, LLAMA3_PATTERN, O200K_BASE_PATTERN, - SENTENCEPIECE_PATTERN, + Tokenize, TokenizeError, Tokenizer, TokenizerError, WordPieceTokenizer, CL100K_BASE_PATTERN, + LLAMA3_PATTERN, O200K_BASE_PATTERN, SENTENCEPIECE_PATTERN, }; // Re-export 
pretrained tokenizer API @@ -17,7 +17,7 @@ pub use core::{ PretrainedVocab, }; -/// Splintr - Fast Rust tokenizer (BPE + SentencePiece) with Python bindings +/// Splintr - Fast Rust tokenizer (BPE + SentencePiece + WordPiece) with Python bindings /// /// A high-performance tokenizer featuring: /// - Regexr with JIT and SIMD (default, pure Rust) @@ -25,6 +25,7 @@ pub use core::{ /// - Rayon parallelism for multi-core encoding /// - Linked-list BPE algorithm (avoids O(N²) on pathological inputs) /// - SentencePiece unigram with greedy longest-match and score-based tie-breaking +/// - WordPiece tokenizer for BERT-family models with `##` continuation prefix /// - FxHashMap for fast lookups /// - Aho-Corasick for fast special token matching /// - LRU cache for frequently encoded chunks diff --git a/uv.lock b/uv.lock index 8774932..cef1c7a 100644 --- a/uv.lock +++ b/uv.lock @@ -3,5 +3,5 @@ requires-python = ">=3.8" [[package]] name = "splintr-rs" -version = "0.9.0" +version = "0.9.1" source = { editable = "." }