Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ TIDB_VECTOR_PORT=4000
TIDB_VECTOR_USER=xxx.root
TIDB_VECTOR_PASSWORD=xxxxxx
TIDB_VECTOR_DATABASE=dify
TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH=false

# Tidb on qdrant configuration
TIDB_ON_QDRANT_URL=http://127.0.0.1
Expand Down
5 changes: 5 additions & 0 deletions api/configs/middleware/vdb/tidb_vector_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ class TiDBVectorConfig(BaseSettings):
description="Name of the TiDB Vector database to connect to",
default=None,
)

TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH: bool = Field(
description="Enable TiDB Vector full-text and hybrid search features",
default=False,
)
4 changes: 3 additions & 1 deletion api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
# Define vector database types that only support semantic search
semantic_only_types = {
VectorType.RELYT,
VectorType.TIDB_VECTOR,
VectorType.CHROMA,
VectorType.PGVECTO_RS,
VectorType.VIKINGDB,
Expand Down Expand Up @@ -383,6 +382,9 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
if vector_type == VectorType.MILVUS:
return semantic_methods if is_mock else full_methods

if vector_type == VectorType.TIDB_VECTOR:
return full_methods if dify_config.TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH else semantic_methods

if vector_type in semantic_only_types:
return semantic_methods
elif vector_type in full_search_types:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class TiDBVectorConfig(BaseModel):
password: str
database: str
program_name: str
enable_fulltext_search: bool = False

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -99,6 +100,11 @@ def _create_collection(self, dimension: int):
if redis_client.get(collection_exist_cache_key):
return
tidb_dist_func = self._get_distance_func()
fulltext_index_statement = (
",\n FULLTEXT INDEX idx_text (text) WITH PARSER MULTILINGUAL"
if self._client_config.enable_fulltext_search
else ""
)
with sessionmaker(bind=self._engine).begin() as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
Expand All @@ -113,6 +119,7 @@ def _create_collection(self, dimension: int):
KEY (doc_id),
KEY (document_id),
VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
{fulltext_index_statement}
);
""")
session.execute(create_statement)
Expand Down Expand Up @@ -241,8 +248,49 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc

@override
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# tidb doesn't support bm25 search
return []
if not self._client_config.enable_fulltext_search or not query:
return []

top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
document_ids_filter = kwargs.get("document_ids_filter")

where_conditions = ["FTS_MATCH_WORD(text, :query)"]
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_conditions.append(f"meta->>'$.document_id' in ({document_ids})")
where_clause = " AND ".join(where_conditions)

docs = []
with Session(self._engine) as session:
select_statement = sql_text(f"""
SELECT meta, text, score
FROM (
SELECT
meta,
text,
FTS_MATCH_WORD(text, :query) AS score
FROM {self._collection_name}
WHERE {where_clause}
ORDER BY score DESC
LIMIT :top_k
) t
WHERE score >= :score_threshold
""")
res = session.execute(
select_statement,
params={
"query": query,
"score_threshold": score_threshold,
"top_k": top_k,
},
)
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, score in results:
metadata = parse_metadata_json(meta)
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs

@override
def delete(self):
Expand Down Expand Up @@ -280,5 +328,6 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
password=dify_config.TIDB_VECTOR_PASSWORD or "",
database=dify_config.TIDB_VECTOR_DATABASE or "",
program_name=dify_config.APPLICATION_NAME,
enable_fulltext_search=dify_config.TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH,
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def _config(tidb_module):
password="secret",
database="dify",
program_name="dify-app",
enable_fulltext_search=False,
)


Expand Down Expand Up @@ -151,15 +152,49 @@ def __exit__(self, exc_type, exc, tb):
vector._collection_name = "collection_1"
vector._engine = MagicMock()
vector._distance_func = "l2"
vector._client_config = _config(tidb_module)

vector._create_collection(3)

sql = str(session.execute.call_args.args[0])
assert "VECTOR<FLOAT>(3)" in sql
assert "VEC_L2_DISTANCE" in sql
assert "FULLTEXT INDEX" not in sql
tidb_module.redis_client.set.assert_called_once()


def test_create_collection_adds_fulltext_index_when_enabled(tidb_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock))
monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock())

session = MagicMock()

class _BeginCtx:
def __enter__(self):
return session

def __exit__(self, exc_type, exc, tb):
return False

mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)

vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1"
vector._engine = MagicMock()
vector._distance_func = "cosine"
vector._client_config = _config(tidb_module).model_copy(update={"enable_fulltext_search": True})

vector._create_collection(3)

sql = str(session.execute.call_args.args[0])
assert "FULLTEXT INDEX idx_text (text) WITH PARSER MULTILINGUAL" in sql


def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch: pytest.MonkeyPatch):
class _InsertStmt:
def __init__(self, table):
Expand Down Expand Up @@ -215,10 +250,38 @@ def __exit__(self, exc_type, exc, tb):
return vector, session, tidb_module


# 1. search_by_full_text returns empty
def test_search_by_full_text_returns_empty(tidb_vector_with_session):
vector, _, _ = tidb_vector_with_session
# 1. search_by_full_text returns empty when disabled
def test_search_by_full_text_returns_empty_when_disabled(tidb_vector_with_session):
vector, session, tidb_module = tidb_vector_with_session
vector._client_config = _config(tidb_module)
assert vector.search_by_full_text("query") == []
session.execute.assert_not_called()


def test_search_by_full_text_queries_tidb_fts_and_scores(tidb_vector_with_session):
vector, session, tidb_module = tidb_vector_with_session
vector._client_config = _config(tidb_module).model_copy(update={"enable_fulltext_search": True})
session.execute.return_value = [
('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.8),
('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.6),
]

docs = vector.search_by_full_text(
"search query",
top_k=2,
score_threshold=0.5,
document_ids_filter=["d-1", "d-2"],
)

assert len(docs) == 2
assert docs[0].page_content == "text-1"
assert docs[0].metadata["score"] == pytest.approx(0.8)
assert docs[1].metadata["score"] == pytest.approx(0.6)
sql = str(session.execute.call_args.args[0])
params = session.execute.call_args.kwargs["params"]
assert "FTS_MATCH_WORD(text, :query)" in sql
assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql
assert params == {"query": "search query", "score_threshold": 0.5, "top_k": 2}


# 2. text_exists returns True when ids found
Expand Down Expand Up @@ -428,6 +491,7 @@ def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeyp
monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root")
monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret")
monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify")
monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH", True)
monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app")

with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls:
Expand All @@ -438,4 +502,5 @@ def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeyp
assert result_2 == "vector"
assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection"
assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection"
assert vector_cls.call_args_list[0].kwargs["config"].enable_fulltext_search is True
assert dataset_without_index.index_struct is not None
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
DatasetRetrievalSettingApi,
DatasetRetrievalSettingMockApi,
DatasetUseCheckApi,
_get_retrieval_methods_by_vector_type,
)
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.storage.storage_type import StorageType
from models.account import Account, TenantAccountRole
from models.dataset import Dataset, DatasetQuery, Document
Expand Down Expand Up @@ -1989,6 +1992,28 @@ def test_get_success(self, app: Flask):

assert "retrieval_method" in response

def test_tidb_vector_returns_semantic_only_when_fulltext_disabled(self):
with patch(
"controllers.console.datasets.datasets.dify_config.TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH",
False,
):
response = _get_retrieval_methods_by_vector_type(VectorType.TIDB_VECTOR)

assert response["retrieval_method"] == [RetrievalMethod.SEMANTIC_SEARCH.value]

def test_tidb_vector_returns_full_methods_when_fulltext_enabled(self):
with patch(
"controllers.console.datasets.datasets.dify_config.TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH",
True,
):
response = _get_retrieval_methods_by_vector_type(VectorType.TIDB_VECTOR)

assert response["retrieval_method"] == [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]


class TestDatasetRetrievalSettingMockApi:
def test_get_success(self, app: Flask):
Expand Down
1 change: 1 addition & 0 deletions docker/envs/core-services/shared.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ TIDB_VECTOR_HOST=tidb
TIDB_VECTOR_PORT=4000
TIDB_VECTOR_USER=
TIDB_VECTOR_PASSWORD=
TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH=false
TIDB_ON_QDRANT_CLIENT_TIMEOUT=20
TIDB_ON_QDRANT_GRPC_ENABLED=false
TIDB_ON_QDRANT_GRPC_PORT=6334
Expand Down