diff --git a/frontend/app/api/queries/useGetFiltersSearchQuery.ts b/frontend/app/api/queries/useGetFiltersSearchQuery.ts index 26c110f53..f85f948fc 100644 --- a/frontend/app/api/queries/useGetFiltersSearchQuery.ts +++ b/frontend/app/api/queries/useGetFiltersSearchQuery.ts @@ -12,6 +12,7 @@ export interface KnowledgeFilter { owner: string; created_at: string; updated_at: string; + active_source_count?: number; } export const useGetFiltersSearchQuery = ( diff --git a/frontend/components/knowledge-filter-list.tsx b/frontend/components/knowledge-filter-list.tsx index 455cf399d..923d70519 100644 --- a/frontend/components/knowledge-filter-list.tsx +++ b/frontend/components/knowledge-filter-list.tsx @@ -155,7 +155,8 @@ export function KnowledgeFilterList({ const dataSources = parseQueryData(filter.query_data) .filters.data_sources; if (dataSources[0] === "*") return "All sources"; - const count = dataSources.length; + const count = + filter.active_source_count ?? dataSources.length; return `${count} ${ count === 1 ? "source" : "sources" }`; diff --git a/frontend/components/knowledge-filter-panel.tsx b/frontend/components/knowledge-filter-panel.tsx index 03957e997..149bbc7bb 100644 --- a/frontend/components/knowledge-filter-panel.tsx +++ b/frontend/components/knowledge-filter-panel.tsx @@ -356,7 +356,13 @@ export function KnowledgeFilterPanel() {
def _user_client(self, user_id: str = None, jwt_token: str = None):
    """Return the OpenSearch client the session manager hands out for this user/token."""
    return self.session_manager.get_user_opensearch_client(user_id, jwt_token)


def _write_client(self, user_id: str = None, jwt_token: str = None):
    """Return the shared `clients.opensearch` client used for write requests.

    `user_id` and `jwt_token` are accepted for signature parity with
    `_user_client` but are not consulted here.

    Raises:
        RuntimeError: if `clients.opensearch` is None.
    """
    # OpenSearch rejects write requests on indices protected by filter-level
    # DLS. The app enforces ownership/visibility with the scoped client, then
    # performs trusted writes with the admin client.
    from config.settings import clients

    if clients.opensearch is None:
        raise RuntimeError("Backend OpenSearch write client is unavailable")
    return clients.opensearch


async def _attach_active_source_counts(self, filters: list[dict[str, Any]]) -> None:
    """Annotate each filter with active_source_count (mutates filters in place).

    Counts how many of each filter's configured data sources still have indexed
    documents, via a single batched terms aggregation against the documents index
    using the admin client (so the count is the same for every viewer of a shared
    filter, not DLS-scoped per user). Filters scoped to "*" are skipped, as are
    filters whose query_data cannot be parsed or whose data_sources is not a
    non-empty list containing at least one string.

    This is best-effort: any failure is logged as a warning and the filters are
    left without the annotation; nothing is raised to the caller.

    Args:
        filters: Filter documents; each may carry a JSON string under "query_data".
    """
    import json
    import logging

    # Resolve the logger *before* the guarded section. The except handler below
    # must not depend on a name that is only bound inside the try block,
    # otherwise an early import failure surfaces as UnboundLocalError.
    try:
        from utils.logging_config import get_logger

        logger = get_logger(__name__)
    except ImportError:
        logger = logging.getLogger(__name__)

    try:
        from config.settings import clients, get_index_name
        from utils.opensearch_queries import build_existing_filenames_agg_body

        # One entry per filter, parallel to `filters`; None means "do not annotate".
        sources_by_filter = []
        all_filenames = set()
        for knowledge_filter in filters:
            try:
                data_sources = (
                    json.loads(knowledge_filter.get("query_data") or "{}")
                    .get("filters", {})
                    .get("data_sources")
                )
            except Exception:
                data_sources = None

            # query_data is stored JSON of unknown shape: only a real list is
            # usable. A bare string would otherwise be split into characters by
            # set.update(), corrupting the lookup for every filter.
            if not isinstance(data_sources, list) or not data_sources or data_sources == ["*"]:
                sources_by_filter.append(None)
                continue

            # Non-string entries can never match a filename bucket key, and
            # unhashable ones would abort the whole batch, so drop them here.
            names = [source for source in data_sources if isinstance(source, str)]
            sources_by_filter.append(names or None)
            all_filenames.update(names)

        if not all_filenames or clients.opensearch is None:
            return

        existence_result = await clients.opensearch.search(
            index=get_index_name(),
            # Sorted so the request body is deterministic for a given set of names.
            body=build_existing_filenames_agg_body(sorted(all_filenames)),
        )
        existing_filenames = {
            bucket["key"]
            for bucket in existence_result["aggregations"]["filenames"]["buckets"]
        }

        for knowledge_filter, names in zip(filters, sources_by_filter):
            if names:
                knowledge_filter["active_source_count"] = sum(
                    1 for source in names if source in existing_filenames
                )
    except Exception:
        logger.warning("active_source_count computation failed", exc_info=True)
def build_existing_filenames_agg_body(filenames: list[str]) -> dict:
    """
    Build a search body for checking which of the given filenames currently have indexed chunks.

    The body matches documents whose `filename` is one of the given names, returns
    no hits (`size: 0`), and aggregates the matches into one `filenames` bucket per
    distinct filename, so the bucket keys are exactly the names that still exist.

    Args:
        filenames: Filenames to check for existence. Duplicates are collapsed
            (first occurrence wins, order preserved).

    Returns:
        A dict containing the complete OpenSearch search body
    """
    # Collapse duplicates so the terms list and the bucket budget stay minimal.
    unique_filenames = list(dict.fromkeys(filenames))
    return {
        "query": {"terms": {"filename": unique_filenames}},
        "size": 0,
        "aggs": {
            "filenames": {
                "terms": {
                    "field": "filename",
                    # A terms aggregation rejects size 0, so keep the body valid
                    # even when called with no filenames.
                    "size": max(len(unique_filenames), 1),
                }
            }
        },
    }
def _setup_search(monkeypatch, filters, existing_filenames):
    """Build a service wired to two fake OpenSearch clients.

    The per-user client answers every search with `filters` as hits. The admin
    client records each request body in `search_calls` and answers with one
    aggregation bucket per entry of `existing_filenames`.
    """
    recorded_bodies = []

    async def _admin_search(*, index, body):
        recorded_bodies.append(body)
        buckets = [{"key": filename} for filename in existing_filenames]
        return {"aggregations": {"filenames": {"buckets": buckets}}}

    fake_admin = SimpleNamespace(search_calls=recorded_bodies, search=_admin_search)

    async def _user_search(*, index, body):
        hits = [{"_source": doc, "_score": 1.0} for doc in filters]
        return {"hits": {"hits": hits}}

    class _FakeSessionManager:
        def get_user_opensearch_client(self, user_id, jwt_token):
            return SimpleNamespace(search=_user_search)

    # The service imports these lazily from config.settings, so patch them there.
    monkeypatch.setattr("config.settings.clients", SimpleNamespace(opensearch=fake_admin))
    monkeypatch.setattr("config.settings.get_index_name", lambda: "documents")

    return KnowledgeFilterService(_FakeSessionManager()), fake_admin


@pytest.mark.asyncio
async def test_search_knowledge_filters_active_source_count_zero_when_document_deleted(
    monkeypatch,
):
    # The only configured source no longer has indexed chunks.
    service, admin_client = _setup_search(
        monkeypatch,
        [_filter("filter-1", data_sources=["README.md"])],
        existing_filenames=set(),
    )

    response = await service.search_knowledge_filters("", user_id="user-1", jwt_token="token")

    assert response["success"] is True
    (only_filter,) = response["filters"][:1]
    assert only_filter["active_source_count"] == 0
    assert len(admin_client.search_calls) == 1


@pytest.mark.asyncio
async def test_search_knowledge_filters_malformed_query_data_fails_silently(monkeypatch):
    malformed = _filter("filter-1", query_data="not json")
    valid = _filter("filter-2", data_sources=["a.md"])
    service, _ = _setup_search(monkeypatch, [malformed, valid], existing_filenames={"a.md"})

    response = await service.search_knowledge_filters("", user_id="user-1", jwt_token="token")

    assert response["success"] is True
    returned = response["filters"]
    assert len(returned) == 2
    # The unparseable filter is passed through without the annotation ...
    assert "active_source_count" not in returned[0]
    # ... while the well-formed one is still counted.
    assert returned[1]["active_source_count"] == 1


@pytest.mark.asyncio
async def test_search_knowledge_filters_dedups_shared_filenames_in_one_query(monkeypatch):
    # Two filters referencing the same file must trigger one lookup with one name.
    shared = [_filter(filter_id, data_sources=["shared.pdf"]) for filter_id in ("filter-1", "filter-2")]
    service, admin_client = _setup_search(monkeypatch, shared, existing_filenames={"shared.pdf"})

    response = await service.search_knowledge_filters("", user_id="user-1", jwt_token="token")

    assert len(admin_client.search_calls) == 1
    (request_body,) = admin_client.search_calls
    assert request_body["query"]["terms"]["filename"] == ["shared.pdf"]
    assert response["filters"][0]["active_source_count"] == 1
    assert response["filters"][1]["active_source_count"] == 1