From 869f614f7b517701a281e1461170810baadae06f Mon Sep 17 00:00:00 2001 From: Baisu Date: Mon, 29 Jun 2026 00:01:20 +0800 Subject: [PATCH] support tidb vector fulltext search --- api/.env.example | 1 + .../middleware/vdb/tidb_vector_config.py | 5 ++ api/controllers/console/datasets/datasets.py | 4 +- .../src/dify_vdb_tidb_vector/tidb_vector.py | 53 +++++++++++++- .../tests/unit_tests/test_tidb_vector.py | 71 ++++++++++++++++++- .../console/datasets/test_datasets.py | 25 +++++++ docker/envs/core-services/shared.env.example | 1 + 7 files changed, 154 insertions(+), 6 deletions(-) diff --git a/api/.env.example b/api/.env.example index 48d8707d1ad3c0..92cedec64ff1f0 100644 --- a/api/.env.example +++ b/api/.env.example @@ -322,6 +322,7 @@ TIDB_VECTOR_PORT=4000 TIDB_VECTOR_USER=xxx.root TIDB_VECTOR_PASSWORD=xxxxxx TIDB_VECTOR_DATABASE=dify +TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH=false # Tidb on qdrant configuration TIDB_ON_QDRANT_URL=http://127.0.0.1 diff --git a/api/configs/middleware/vdb/tidb_vector_config.py b/api/configs/middleware/vdb/tidb_vector_config.py index 0ebf226bea665b..172b5a7e56efc0 100644 --- a/api/configs/middleware/vdb/tidb_vector_config.py +++ b/api/configs/middleware/vdb/tidb_vector_config.py @@ -31,3 +31,8 @@ class TiDBVectorConfig(BaseSettings): description="Name of the TiDB Vector database to connect to", default=None, ) + + TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH: bool = Field( + description="Enable TiDB Vector full-text and hybrid search features", + default=False, + ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 70ce54830c7c5f..953e5256d32d08 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -335,7 +335,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool # Define vector database types that only support semantic search semantic_only_types = { VectorType.RELYT, - VectorType.TIDB_VECTOR, VectorType.CHROMA, VectorType.PGVECTO_RS, VectorType.VIKINGDB, @@ -383,6 +382,9 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool if vector_type == VectorType.MILVUS: return semantic_methods if is_mock else full_methods + if vector_type == VectorType.TIDB_VECTOR: + return full_methods if dify_config.TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH else semantic_methods + if vector_type in semantic_only_types: return semantic_methods elif vector_type in full_search_types: diff --git a/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py index 9f80ae5a76fed6..7bdb13c6636388 100644 --- a/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py +++ b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py @@ -28,6 +28,7 @@ class TiDBVectorConfig(BaseModel): password: str database: str program_name: str + enable_fulltext_search: bool = False @model_validator(mode="before") @classmethod @@ -99,6 +100,11 @@ def _create_collection(self, dimension: int): if redis_client.get(collection_exist_cache_key): return tidb_dist_func = self._get_distance_func() + fulltext_index_statement = ( + ",\n FULLTEXT INDEX idx_text (text) WITH PARSER MULTILINGUAL" + if self._client_config.enable_fulltext_search + else "" + ) with sessionmaker(bind=self._engine).begin() as session: create_statement = sql_text(f""" CREATE TABLE IF NOT EXISTS {self._collection_name} ( @@ -113,6 +119,7 @@ def _create_collection(self, dimension: int): KEY (doc_id), KEY (document_id), VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW + {fulltext_index_statement} ); """) session.execute(create_statement) @@ -241,8 +248,49 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc @override def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - # tidb doesn't support bm25 search - return [] + if not self._client_config.enable_fulltext_search or not query: + return [] + + top_k = kwargs.get("top_k", 4) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + document_ids_filter = kwargs.get("document_ids_filter") + + where_conditions = ["FTS_MATCH_WORD(text, :query)"] + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_conditions.append(f"meta->>'$.document_id' in ({document_ids})") + where_clause = " AND ".join(where_conditions) + + docs = [] + with Session(self._engine) as session: + select_statement = sql_text(f""" + SELECT meta, text, score + FROM ( + SELECT + meta, + text, + FTS_MATCH_WORD(text, :query) AS score + FROM {self._collection_name} + WHERE {where_clause} + ORDER BY score DESC + LIMIT :top_k + ) t + WHERE score >= :score_threshold + """) + res = session.execute( + select_statement, + params={ + "query": query, + "score_threshold": score_threshold, + "top_k": top_k, + }, + ) + results = [(row[0], row[1], row[2]) for row in res] + for meta, text, score in results: + metadata = parse_metadata_json(meta) + metadata["score"] = score + docs.append(Document(page_content=text, metadata=metadata)) + return docs @override def delete(self): @@ -280,5 +328,6 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings password=dify_config.TIDB_VECTOR_PASSWORD or "", database=dify_config.TIDB_VECTOR_DATABASE or "", program_name=dify_config.APPLICATION_NAME, + enable_fulltext_search=dify_config.TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH, ), ) diff --git a/api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py index ed03cbee88d260..0c6c51ae2819fe 100644 --- a/api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py +++ b/api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py @@ -25,6 +25,7 @@ def _config(tidb_module): password="secret", database="dify", program_name="dify-app", + enable_fulltext_search=False, ) @@ -151,15 +152,49 @@ def __exit__(self, exc_type, exc, tb): vector._collection_name = "collection_1" vector._engine = MagicMock() vector._distance_func = "l2" + vector._client_config = _config(tidb_module) vector._create_collection(3) sql = str(session.execute.call_args.args[0]) assert "VECTOR(3)" in sql assert "VEC_L2_DISTANCE" in sql + assert "FULLTEXT INDEX" not in sql tidb_module.redis_client.set.assert_called_once() +def test_create_collection_adds_fulltext_index_when_enabled(tidb_module, monkeypatch: pytest.MonkeyPatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + session = MagicMock() + + class _BeginCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx())) + monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "cosine" + vector._client_config = _config(tidb_module).model_copy(update={"enable_fulltext_search": True}) + + vector._create_collection(3) + + sql = str(session.execute.call_args.args[0]) + assert "FULLTEXT INDEX idx_text (text) WITH PARSER MULTILINGUAL" in sql + + def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch: pytest.MonkeyPatch): class _InsertStmt: def __init__(self, table): @@ -215,10 +250,38 @@ def __exit__(self, exc_type, exc, tb): return vector, session, tidb_module -# 1. search_by_full_text returns empty -def test_search_by_full_text_returns_empty(tidb_vector_with_session): - vector, _, _ = tidb_vector_with_session +# 1. search_by_full_text returns empty when disabled +def test_search_by_full_text_returns_empty_when_disabled(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + vector._client_config = _config(tidb_module) assert vector.search_by_full_text("query") == [] + session.execute.assert_not_called() + + +def test_search_by_full_text_queries_tidb_fts_and_scores(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + vector._client_config = _config(tidb_module).model_copy(update={"enable_fulltext_search": True}) + session.execute.return_value = [ + ('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.8), + ('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.6), + ] + + docs = vector.search_by_full_text( + "search query", + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + assert docs[0].metadata["score"] == pytest.approx(0.8) + assert docs[1].metadata["score"] == pytest.approx(0.6) + sql = str(session.execute.call_args.args[0]) + params = session.execute.call_args.kwargs["params"] + assert "FTS_MATCH_WORD(text, :query)" in sql + assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql + assert params == {"query": "search query", "score_threshold": 0.5, "top_k": 2} # 2. text_exists returns True when ids found @@ -428,6 +491,7 @@ def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeyp monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root") monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret") monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH", True) monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app") with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls: @@ -438,4 +502,5 @@ def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeyp assert result_2 == "vector" assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert vector_cls.call_args_list[0].kwargs["config"].enable_fulltext_search is True assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 76a0955898754e..0f6a3978a734a6 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -29,11 +29,14 @@ DatasetRetrievalSettingApi, DatasetRetrievalSettingMockApi, DatasetUseCheckApi, + _get_retrieval_methods_by_vector_type, ) from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.provider_manager import ProviderManager +from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.storage.storage_type import StorageType from models.account import Account, TenantAccountRole from models.dataset import Dataset, DatasetQuery, Document @@ -1989,6 +1992,28 @@ def test_get_success(self, app: Flask): assert "retrieval_method" in response + def test_tidb_vector_returns_semantic_only_when_fulltext_disabled(self): + with patch( + "controllers.console.datasets.datasets.dify_config.TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH", + False, + ): + response = _get_retrieval_methods_by_vector_type(VectorType.TIDB_VECTOR) + + assert response["retrieval_method"] == [RetrievalMethod.SEMANTIC_SEARCH.value] + + def test_tidb_vector_returns_full_methods_when_fulltext_enabled(self): + with patch( + "controllers.console.datasets.datasets.dify_config.TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH", + True, + ): + response = _get_retrieval_methods_by_vector_type(VectorType.TIDB_VECTOR) + + assert response["retrieval_method"] == [ + RetrievalMethod.SEMANTIC_SEARCH.value, + RetrievalMethod.FULL_TEXT_SEARCH.value, + RetrievalMethod.HYBRID_SEARCH.value, + ] + class TestDatasetRetrievalSettingMockApi: def test_get_success(self, app: Flask): diff --git a/docker/envs/core-services/shared.env.example b/docker/envs/core-services/shared.env.example index 391dba2e21adce..ad879de1fb54e1 100644 --- a/docker/envs/core-services/shared.env.example +++ b/docker/envs/core-services/shared.env.example @@ -394,6 +394,7 @@ TIDB_VECTOR_HOST=tidb TIDB_VECTOR_PORT=4000 TIDB_VECTOR_USER= TIDB_VECTOR_PASSWORD= +TIDB_VECTOR_ENABLE_FULLTEXT_SEARCH=false TIDB_ON_QDRANT_CLIENT_TIMEOUT=20 TIDB_ON_QDRANT_GRPC_ENABLED=false TIDB_ON_QDRANT_GRPC_PORT=6334