99 changes: 99 additions & 0 deletions openviking/models/embedder/base.py
@@ -1,6 +1,8 @@
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0
import logging
import random
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -9,6 +11,89 @@
T = TypeVar("T")


_tiktoken_encoder = None
_tiktoken_lock = threading.Lock()
_TIKTOKEN_NOT_AVAILABLE = object() # sentinel: initialization was attempted but failed


def _get_tiktoken_encoder():
"""Get cached tiktoken encoder (module-level singleton, downloaded once).

Returns None if tiktoken is unavailable. The unavailable state is cached so
that import is only attempted once and the warning is only logged once.
"""
global _tiktoken_encoder
if _tiktoken_encoder is None:
with _tiktoken_lock:
if _tiktoken_encoder is None:
try:
import tiktoken

_tiktoken_encoder = tiktoken.get_encoding("cl100k_base")
except Exception as e:
logging.getLogger(__name__).warning(
f"tiktoken unavailable, falling back to byte-based truncation: {e}"
)
_tiktoken_encoder = _TIKTOKEN_NOT_AVAILABLE
return None if _tiktoken_encoder is _TIKTOKEN_NOT_AVAILABLE else _tiktoken_encoder


def truncate_text_by_tokens(text: str, max_tokens: int) -> str:
"""Truncate text to at most max_tokens tokens. Returns original text if within limit.

Uses cl100k_base (OpenAI BPE) as a universal tokenizer approximation.
This is exact for OpenAI embedding models, but only approximate for others:

- BGE / bge-m3: based on XLM-RoBERTa + SentencePiece; CJK text produces
~1.76 tokens/character, whereas cl100k_base merges CJK characters into
larger tokens and may undercount by up to ~3-7x for Chinese-heavy input.
- Jina / Voyage: also use their own tokenizers; token counts may differ.

For non-OpenAI models, callers should set a conservative max_tokens limit
(e.g. well below the model's hard limit) to absorb the tokenizer gap.

Falls back to UTF-8 byte truncation if tiktoken is unavailable.

Args:
text: Input text to truncate
max_tokens: Maximum number of tokens allowed

Returns:
Truncated text (or original text if already within limit)
"""
# Fast path: for any tokenizer, token count <= UTF-8 byte count
# (each token covers at least 1 byte), so byte count <= max_tokens
# guarantees token count <= max_tokens for ALL models including BGE.
# When this check fails and we fall through to cl100k_base, the token
# count estimate may be too low for non-OpenAI models (e.g. BGE
# SentencePiece produces ~1.76 tokens/CJK char while cl100k_base merges
# several chars into one token). In practice this is mitigated because:
# - bge-large-zh has a 500-token limit (~170 CJK chars), so nearly all
# inputs that reach this function will be caught by the fast path.
# - bge-m3 has an 8000-token limit with a wide enough margin.
# Callers should set conservative max_tokens limits for non-OpenAI models.
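# (Worked example with the numbers above: CJK characters are typically 3 bytes
# in UTF-8, so the 500-byte fast path admits roughly 500/3 ≈ 166 CJK
# characters; at ~1.76 BGE tokens/char that is ~290 tokens, comfortably under
# the 500-token limit.)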
if len(text.encode("utf-8")) <= max_tokens:
return text

enc = _get_tiktoken_encoder()
if enc is not None:
try:
# disallowed_special=() avoids ValueError when text contains special token strings
tokens = enc.encode(text, disallowed_special=())
if len(tokens) <= max_tokens:
return text
return enc.decode(tokens[:max_tokens])
except Exception:
pass # Fall through to byte-based truncation

# Fallback: UTF-8 byte truncation (tiktoken unavailable).
# Guaranteed safe: token_count <= utf8_bytes, truncating to max_tokens bytes
# ensures token_count <= max_tokens.
encoded = text.encode("utf-8")
truncated = encoded[:max_tokens].decode("utf-8", errors="ignore")
return truncated
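# Usage sketch (hypothetical inputs; exact results depend on whether tiktoken
# is installed, since the fallback truncates by UTF-8 bytes instead):
#
#   assert truncate_text_by_tokens("hi", 10) == "hi"  # 2 bytes <= 10: fast path
#   clipped = truncate_text_by_tokens("word " * 10_000, 100)
#   # clipped decodes back to at most 100 cl100k_base tokens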


def truncate_and_normalize(embedding: List[float], dimension: Optional[int]) -> List[float]:
"""Truncate and L2 normalize embedding vector

@@ -100,6 +185,20 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
"""
return [self.embed(text, is_query=is_query) for text in texts]

@property
def max_input_tokens(self) -> int:
"""Maximum number of tokens allowed as input. Subclasses can override."""
return 8000

def _truncate_input(self, text: str) -> str:
"""Truncate input text to max_input_tokens. Logs a warning if truncation occurs."""
truncated = truncate_text_by_tokens(text, self.max_input_tokens)
if len(truncated) < len(text):
logging.getLogger(__name__).warning(
f"[{self.__class__.__name__}] Input truncated to {self.max_input_tokens} tokens"
)
return truncated
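# Sketch of a provider-specific override (MyProviderEmbedder is hypothetical;
# the concrete embedders below override the property the same way):
#
#   class MyProviderEmbedder(DenseEmbedderBase):
#       @property
#       def max_input_tokens(self) -> int:
#           return 500  # model-specific hard limit, minus a safety buffer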

def close(self):
"""Release resources, subclasses can override as needed"""
pass
13 changes: 13 additions & 0 deletions openviking/models/embedder/jina_embedders.py
@@ -17,6 +17,12 @@
"jina-embeddings-v5-text-nano": 768, # 239M params, max seq 8192
}

# Max input tokens for Jina embedding models (with a ~200-token buffer for tokenizer differences)
JINA_MODEL_MAX_TOKENS = {
"jina-embeddings-v5-text-small": 32000,
"jina-embeddings-v5-text-nano": 8000,
}
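# e.g. jina-embeddings-v5-text-nano has max seq 8192 (see the dimension table
# above), so 8000 leaves ~192 tokens of headroom; the 32000 for the small
# model likewise assumes a 32k-class sequence limit.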


class JinaDenseEmbedder(DenseEmbedderBase):
"""Jina AI Dense Embedder Implementation
@@ -106,6 +112,11 @@ def __init__(
f"Jina models support Matryoshka dimension reduction up to {max_dim}."
)
self._dimension = dimension if dimension is not None else max_dim
self._max_input_tokens = JINA_MODEL_MAX_TOKENS.get(model_name, 8192)

@property
def max_input_tokens(self) -> int:
return self._max_input_tokens

def _build_extra_body(self, is_query: bool = False) -> Optional[Dict[str, Any]]:
"""Build extra_body dict for Jina-specific parameters"""
@@ -136,6 +147,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
RuntimeError: When API call fails
"""
try:
text = self._truncate_input(text)
kwargs: Dict[str, Any] = {"input": text, "model": self.model_name}
if self.dimension:
kwargs["dimensions"] = self.dimension
@@ -170,6 +182,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
return []

try:
texts = [self._truncate_input(t) for t in texts]
kwargs: Dict[str, Any] = {"input": texts, "model": self.model_name}
if self.dimension:
kwargs["dimensions"] = self.dimension
2 changes: 2 additions & 0 deletions openviking/models/embedder/openai_embedders.py
@@ -225,6 +225,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
RuntimeError: When API call fails
"""
try:
text = self._truncate_input(text)
kwargs: Dict[str, Any] = {"input": text, "model": self.model_name}

extra_body = self._build_extra_body(is_query=is_query)
@@ -258,6 +259,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
return []

try:
texts = [self._truncate_input(t) for t in texts]
kwargs: Dict[str, Any] = {"input": texts, "model": self.model_name}
if self.dimension:
kwargs["dimensions"] = self.dimension
39 changes: 39 additions & 0 deletions openviking/models/embedder/vikingdb_embedders.py
@@ -13,6 +16,16 @@
from openviking.storage.vectordb.collection.volcengine_clients import ClientForDataApi
from openviking_cli.utils.logger import default_logger as logger

# Max input tokens per VikingDB model (text truncation length, with a small buffer)
VIKINGDB_MODEL_MAX_TOKENS = {
"bge-large-zh": 500,
"bge-m3": 8000,
"bge-visualized-m3": 8000,
"doubao-embedding-large": 4000,
"doubao-embedding-vision": 8000,
"doubao-embedding": 4000,
}
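# NOTE: _get_max_input_tokens below matches these keys by substring against
# the lowercased model name, so insertion order matters:
# "doubao-embedding-large" and "doubao-embedding-vision" must come before
# "doubao-embedding", which is a substring of both.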


class VikingDBClientMixin:
"""Mixin to handle VikingDB Client initialization and API calls."""
@@ -81,6 +91,14 @@ def _truncate_and_normalize(
embedding = [x / norm for x in embedding]
return embedding

def _get_max_input_tokens(self) -> int:
"""Resolve max input tokens based on model name."""
name = self.model_name.lower()
for key, limit in VIKINGDB_MODEL_MAX_TOKENS.items():
if key in name:
return limit
return 4000 # conservative default

def _process_sparse_embedding(self, sparse_data: Any) -> Dict[str, float]:
"""Process sparse embedding data"""
if not sparse_data:
@@ -104,6 +122,10 @@ def _process_sparse_embedding(self, sparse_data: Any) -> Dict[str, float]:
class VikingDBDenseEmbedder(DenseEmbedderBase, VikingDBClientMixin):
"""VikingDB Dense Embedder"""

@property
def max_input_tokens(self) -> int:
return self._max_input_tokens

def __init__(
self,
model_name: str,
@@ -122,8 +144,10 @@ def __init__(
self.dimension = dimension
self.embedding_type = embedding_type
self.dense_model = {"name": model_name, "version": model_version, "dim": dimension}
self._max_input_tokens = self._get_max_input_tokens()

def embed(self, text: str, is_query: bool = False) -> EmbedResult:
text = self._truncate_input(text)
results = self._call_api([text], dense_model=self.dense_model)
if not results:
return EmbedResult(dense_vector=[])
@@ -138,6 +162,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
if not texts:
return []
texts = [self._truncate_input(t) for t in texts]
raw_results = self._call_api(texts, dense_model=self.dense_model)
return [
EmbedResult(
@@ -155,6 +180,10 @@ def get_dimension(self) -> int:
class VikingDBSparseEmbedder(SparseEmbedderBase, VikingDBClientMixin):
"""VikingDB Sparse Embedder"""

@property
def max_input_tokens(self) -> int:
return self._max_input_tokens

def __init__(
self,
model_name: str,
@@ -168,12 +197,14 @@ def __init__(
SparseEmbedderBase.__init__(self, model_name, config)
self._init_vikingdb_client(ak, sk, region, host)
self.model_version = model_version
self._max_input_tokens = self._get_max_input_tokens()
self.sparse_model = {
"name": model_name,
"version": model_version,
}

def embed(self, text: str, is_query: bool = False) -> EmbedResult:
text = self._truncate_input(text)
results = self._call_api([text], sparse_model=self.sparse_model)
if not results:
return EmbedResult(sparse_vector={})
@@ -188,6 +219,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
if not texts:
return []
texts = [self._truncate_input(t) for t in texts]
raw_results = self._call_api(texts, sparse_model=self.sparse_model)
return [
EmbedResult(
@@ -200,6 +232,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
class VikingDBHybridEmbedder(HybridEmbedderBase, VikingDBClientMixin):
"""VikingDB Hybrid Embedder"""

@property
def max_input_tokens(self) -> int:
return self._max_input_tokens

def __init__(
self,
model_name: str,
@@ -217,13 +253,15 @@ def __init__(
self.model_version = model_version
self.dimension = dimension
self.embedding_type = embedding_type
self._max_input_tokens = self._get_max_input_tokens()
self.dense_model = {"name": model_name, "version": model_version, "dim": dimension}
self.sparse_model = {
"name": model_name,
"version": model_version,
}

def embed(self, text: str, is_query: bool = False) -> EmbedResult:
text = self._truncate_input(text)
results = self._call_api(
[text], dense_model=self.dense_model, sparse_model=self.sparse_model
)
@@ -244,6 +282,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
if not texts:
return []
texts = [self._truncate_input(t) for t in texts]
raw_results = self._call_api(
texts, dense_model=self.dense_model, sparse_model=self.sparse_model
)
19 changes: 19 additions & 0 deletions openviking/models/embedder/volcengine_embedders.py
@@ -75,6 +75,10 @@ class VolcengineDenseEmbedder(DenseEmbedderBase):
Supports Volcengine embedding models such as doubao-embedding.
"""

@property
def max_input_tokens(self) -> int:
return 8000 if "vision" in self.model_name.lower() else 4000

def __init__(
self,
model_name: str,
@@ -159,6 +163,8 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
RuntimeError: When API call fails
"""

text = self._truncate_input(text)

def _embed_call():
if self.input_type == "multimodal":
# Use multimodal embeddings API
@@ -206,6 +212,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
return []

try:
texts = [self._truncate_input(t) for t in texts]
if self.input_type == "multimodal":
multimodal_inputs = [{"type": "text", "text": text} for text in texts]
response = self.client.multimodal_embeddings.create(
@@ -238,6 +245,10 @@ class VolcengineSparseEmbedder(SparseEmbedderBase):
Generates sparse embeddings using Volcengine's multimodal embedding API.
"""

@property
def max_input_tokens(self) -> int:
return 8000 if "vision" in self.model_name.lower() else 4000

def __init__(
self,
model_name: str,
@@ -283,6 +294,8 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
RuntimeError: When API call fails
"""

text = self._truncate_input(text)

def _embed_call():
# Must use multimodal endpoint for sparse
response = self.client.multimodal_embeddings.create(
@@ -332,6 +345,10 @@ class VolcengineHybridEmbedder(HybridEmbedderBase):
multimodal embedding API.
"""

@property
def max_input_tokens(self) -> int:
return 8000 if "vision" in self.model_name.lower() else 4000

def __init__(
self,
model_name: str,
@@ -383,6 +400,8 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
RuntimeError: When API call fails
"""

text = self._truncate_input(text)

def _embed_call():
# Always use multimodal for hybrid to get both

2 changes: 2 additions & 0 deletions openviking/models/embedder/voyage_embedders.py
@@ -84,6 +84,7 @@ def __init__(
def embed(self, text: str, is_query: bool = False) -> EmbedResult:
"""Perform dense embedding on text."""
try:
text = self._truncate_input(text)
kwargs: Dict[str, Any] = {"input": text, "model": self.model_name}
if self.dimension is not None:
kwargs["extra_body"] = {"output_dimension": self.dimension}
@@ -102,6 +103,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
return []

try:
texts = [self._truncate_input(t) for t in texts]
kwargs: Dict[str, Any] = {"input": texts, "model": self.model_name}
if self.dimension is not None:
kwargs["extra_body"] = {"output_dimension": self.dimension}