def get_password_hash(password: str) -> str:
    """Hash a password with bcrypt using a freshly generated salt.

    bcrypt only looks at the first 72 bytes of its input. Older releases of
    the ``bcrypt`` package truncated longer input silently, while 5.0+ raises
    ``ValueError`` instead. The encoded password is truncated explicitly here
    so the result is the same on every version.

    Args:
        password: The plain text password to hash.

    Returns:
        The bcrypt hash, decoded to a UTF-8 string.
    """
    # Slice the encoded bytes (not the characters): 72 *bytes* is the bcrypt limit.
    secret = password.encode("utf-8")[:72]
    return _bcrypt.hashpw(secret, _bcrypt.gensalt()).decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a plain password against a bcrypt hash.

    Args:
        plain_password: The plain text password to verify.
        hashed_password: The stored bcrypt hash to compare against.

    Returns:
        True if the password matches, False otherwise. A stored value that is
        not a well-formed bcrypt hash also yields False.
    """
    # Keep only the first 72 bytes, the bcrypt input limit: older bcrypt
    # releases truncated silently, 5.0+ raises ValueError on longer input.
    secret = plain_password.encode("utf-8")[:72]
    try:
        return _bcrypt.checkpw(secret, hashed_password.encode("utf-8"))
    except ValueError:
        # bcrypt signals a malformed hash/salt with ValueError; for a boolean
        # credential check that simply means "no match".
        return False
-_QVAC_INGEST_DIR = Path(os.getenv("QVAC_INGEST_DIR", "")) - _client = httpx.AsyncClient(base_url=_QVAC_SERVICE_URL, timeout=60.0) -_reranker = None - - -def _get_reranker(): - global _reranker - if _reranker is None: - try: - from flashrank import Ranker - _reranker = Ranker(model_name="ms-marco-MiniLM-L-6-v2", cache_dir="/tmp/flashrank") - logger.info("FlashRank reranker loaded (ms-marco-MiniLM-L-6-v2)") - except Exception as exc: - logger.warning("FlashRank unavailable — skipping reranking: %s", exc) - return _reranker - - -def _rerank_sources(question: str, sources: list) -> list: - """Rerank with FlashRank cross-encoder; returns top _TOP_K_GENERATE results.""" - if len(sources) <= 1: - return sources[:_TOP_K_GENERATE] - reranker = _get_reranker() - if reranker is None: - return sources[:_TOP_K_GENERATE] - try: - from flashrank import RerankRequest - passages = [ - {"id": i, "text": s.get("content") or s.get("snippet", "")} - for i, s in enumerate(sources) - ] - reranked = reranker.rerank(RerankRequest(query=question, passages=passages)) - top_ids = [r["id"] for r in reranked[:_TOP_K_GENERATE]] - return [sources[i] for i in top_ids] - except Exception as exc: - logger.warning("FlashRank reranking failed, falling back to dense order: %s", exc) - return sources[:_TOP_K_GENERATE] - - -# --------------------------------------------------------------------------- -# BM25 helpers -# --------------------------------------------------------------------------- - -def _bm25_search(question: str, course_id: str, top_k: int = 20) -> list[dict]: - """Query the BM25 sparse index for the course; returns [{chunk_id, score}].""" - if not _QVAC_INGEST_DIR: - return [] - bm25_path = _QVAC_INGEST_DIR / f"{course_id}_bm25.pkl" - if not bm25_path.exists(): - return [] - try: - with bm25_path.open("rb") as f: - data = pickle.load(f) - bm25 = data["bm25"] - ids = data["ids"] - tokens = question.lower().split() - scores = bm25.get_scores(tokens) - ranked = sorted(range(len(scores)), 
key=lambda i: scores[i], reverse=True)[:top_k] - return [{"chunk_id": ids[i], "score": float(scores[i])} for i in ranked if scores[i] > 0] - except Exception as exc: - logger.warning("BM25 search failed for course '%s': %s", course_id, exc) - return [] - - -def _rrf_merge( - dense_chunks: list[dict], - bm25_results: list[dict], - top_n: int = 20, - k: int = 60, -) -> list[str]: - """Reciprocal Rank Fusion over dense (QVAC) and sparse (BM25) results. - - Dense chunks are keyed by their original chunk_id field. - Returns an ordered list of chunk_ids. - """ - rrf: dict[str, float] = {} - - for rank, chunk in enumerate(dense_chunks): - cid = chunk.get("chunk_id", "") - if cid: - rrf[cid] = rrf.get(cid, 0.0) + 1.0 / (rank + k) - - for rank, item in enumerate(bm25_results): - cid = item["chunk_id"] - rrf[cid] = rrf.get(cid, 0.0) + 1.0 / (rank + k) - - return sorted(rrf.keys(), key=lambda x: rrf[x], reverse=True)[:top_n] - - -def _load_corpus(course_id: str) -> dict: - """Load the BM25 corpus JSON for the course (lazy, per-request).""" - if not _QVAC_INGEST_DIR: - return {} - corpus_path = _QVAC_INGEST_DIR / f"{course_id}_corpus.json" - if not corpus_path.exists(): - return {} - try: - with corpus_path.open(encoding="utf-8") as f: - return json.load(f) - except (OSError, json.JSONDecodeError): - return {} - - -def _resolve_merged( - merged_ids: list[str], - dense_registry: dict[str, dict], - course_id: str, -) -> list[dict]: - """Map RRF-merged chunk_ids to full chunk info. - - Uses QVAC dense results first; falls back to corpus.json for BM25-only chunks. 
def _qvac_dict_to_chunk(d: dict) -> EvidenceChunk:
    """Convert one chunk dict from the QVAC /retrieve response to an EvidenceChunk.

    Missing keys and explicit ``None`` values (JSON ``null``) are treated the
    same way: string fields fall back to ``""``, the score to ``0.0``, and
    ``section``/``page``/``slide`` to ``None``. A falsy page or slide (``0``)
    is also mapped to ``None``.

    Args:
        d: Chunk dict; the keys read are ``chunk_id``, ``content`` (or
            ``text`` when ``content`` is empty), ``score``, ``doc_id``,
            ``label``, ``section``, ``page`` and ``slide``.

    Returns:
        An EvidenceChunk whose anchor carries the citation metadata.
    """
    chunk_id = d.get("chunk_id") or ""
    return EvidenceChunk(
        chunk_id=chunk_id,
        # "content" wins; "text" is only consulted when "content" is empty.
        text=d.get("content") or d.get("text") or "",
        # `or 0.0` guards float(None) when the key is present but null.
        score=float(d.get("score") or 0.0),
        anchor=CitationAnchor(
            doc_id=d.get("doc_id") or "",
            doc_name=d.get("label") or "",
            section=d.get("section") or None,
            page=int(d["page"]) if d.get("page") else None,
            slide=int(d["slide"]) if d.get("slide") else None,
            chunk_id=chunk_id,
            chunk_type="paragraph",
        ),
    )
Dense retrieval try: resp = await _client.post( @@ -247,51 +117,46 @@ async def answer(question: str, course_id: str) -> ChatResult: json={"question": question, "workspace": course_id, "topK": _TOP_K_RETRIEVE}, ) resp.raise_for_status() - dense_data = resp.json() - dense_chunks: list[dict] = dense_data.get("chunks", []) + dense_dicts: list[dict] = resp.json().get("chunks", []) except httpx.HTTPError as exc: logger.warning("QVAC /retrieve unavailable (%s) — trying ChromaDB fallback", exc) return _chroma_chat_result(question, course_id) - if not dense_chunks: + if not dense_dicts: logger.info("QVAC returned 0 chunks for course '%s', trying ChromaDB fallback", course_id) fallback = _chroma_chat_result(question, course_id) if fallback.citations: return fallback - # 2. BM25 sparse retrieval - bm25_results = _bm25_search(question, course_id, top_k=_TOP_K_RETRIEVE) + # 2. Convert QVAC dicts → EvidenceChunk for unified processing + dense_chunks = [_qvac_dict_to_chunk(d) for d in dense_dicts if d.get("chunk_id")] - # 3. RRF merge - dense_registry = {c["chunk_id"]: c for c in dense_chunks if c.get("chunk_id")} - if bm25_results: - merged_ids = _rrf_merge(dense_chunks, bm25_results, top_n=_TOP_K_RETRIEVE) - merged_chunks = _resolve_merged(merged_ids, dense_registry, course_id) - else: - merged_chunks = dense_chunks[:_TOP_K_RETRIEVE] + # 3. BM25 sparse retrieval + bm25_hits = hybrid_search.bm25_search(question, course_id, top_k=_TOP_K_RETRIEVE) - # 4. FlashRank rerank → top-5 - reranked = _rerank_sources(question, merged_chunks) + # 4. RRF fusion — falls back to dense-only when BM25 index is absent + if bm25_hits: + index_data = hybrid_search.load_bm25_index(course_id) + corpus = index_data[2] if index_data else {} + merged = hybrid_search.rrf_fuse(dense_chunks, bm25_hits, corpus, top_k=_TOP_K_RETRIEVE) + else: + logger.debug("BM25 index absent for course '%s' — dense-only retrieval", course_id) + merged = dense_chunks[:_TOP_K_RETRIEVE] - # 5. 
Parent text lookup (async-safe: sync DB call via thread) - parent_ids = list({c.get("parent_id", "") for c in reranked if c.get("parent_id")}) - parent_map: dict[str, dict] = await asyncio.to_thread(_load_parents_from_db, parent_ids) + # 5. Rerank with FlashRank cross-encoder → keep top _TOP_K_GENERATE + reranked_all = reranker.rerank(question, merged) + reranked = reranked_all[:_TOP_K_GENERATE] - # 6. Build LLM context: use parent text when available (richer context window) - context_blocks: list[dict] = [] - seen_parents: set[str] = set() - for c in reranked: - pid = c.get("parent_id", "") - if pid and pid in parent_map and pid not in seen_parents: - seen_parents.add(pid) - context_blocks.append({"label": parent_map[pid]["label"], "text": parent_map[pid]["text"]}) - elif not pid or pid not in parent_map: - context_blocks.append({"label": c.get("label", ""), "text": c.get("content", "")}) + # 6. Expand child chunks → parent context (richer LLM context window) + context_chunks = parent_expansion.expand_to_parents(reranked) - if not context_blocks: - context_blocks = [{"label": c.get("label", ""), "text": c.get("content", "")} for c in reranked] + # 7. Build context blocks for LLM generation + context_blocks = [ + {"label": c.anchor.doc_name, "text": c.text} + for c in context_chunks + ] - # 7. LLM generation with parent contexts + # 8. LLM generation answer_text = "" try: gen_resp = await _client.post( @@ -304,16 +169,16 @@ async def answer(question: str, course_id: str) -> ChatResult: logger.warning("QVAC /generate failed (%s) — returning first context block", exc) answer_text = context_blocks[0]["text"] if context_blocks else "Risposta non disponibile." - # 8. Citations (child-level for precise source attribution) + # 9. 
Citations from child chunks (preserves page/slide precision) citations = [ Citation( - snippet=(c.get("content") or c.get("snippet", ""))[:200], - score=c.get("score", 0.0), - label=c.get("label", ""), - page=c.get("page", 0), - slide=c.get("slide", 0), - section=c.get("section", ""), - doc_id=c.get("doc_id", ""), + snippet=c.text[:200], + score=c.score, + label=c.anchor.doc_name, + page=c.anchor.page or 0, + slide=c.anchor.slide or 0, + section=c.anchor.section or "", + doc_id=c.anchor.doc_id, ) for c in reranked ] diff --git a/services/ai/app/workers/pipeline.py b/services/ai/app/workers/pipeline.py index 144b6ce..17d0e33 100644 --- a/services/ai/app/workers/pipeline.py +++ b/services/ai/app/workers/pipeline.py @@ -32,16 +32,45 @@ from app.db.session import get_db_context # noqa: E402 from app.repositories import document_repo # noqa: E402 +# --------------------------------------------------------------------------- +# sys.modules aliasing — dual-import guard +# --------------------------------------------------------------------------- +# Register 'services.ai.app.*' as aliases for 'app.*' in sys.modules. +# This ensures that classes imported via either path are the same objects, +# preventing silent Pydantic isinstance failures when the worker is invoked +# from the project root instead of from services/ai/. 
import sys as _sys
import types as _types


def _register_module_aliases() -> None:
    """Alias every loaded ``app`` module under ``services.ai.app`` in ``sys.modules``.

    After the call, ``sys.modules["services.ai.app<suffix>"]`` is the very same
    module object as ``sys.modules["app<suffix>"]`` for each ``app`` module that
    is already imported, so classes reached through either dotted path are
    identical objects. Existing ``services.*`` entries are never overwritten.

    Only modules present in ``sys.modules`` at call time are aliased; modules
    imported later are not picked up.
    """
    canonical = "app"
    alias_root = "services.ai.app"

    # Alias first: registering a placeholder for ``services.ai.app`` before
    # this loop would make setdefault() keep the empty stub instead of the
    # real ``app`` package.
    prefix = canonical + "."
    for name, module in list(_sys.modules.items()):
        if name == canonical or name.startswith(prefix):
            long_name = alias_root + name[len(canonical):]
            _sys.modules.setdefault(long_name, module)

    # Fill in whatever part of the parent chain is still missing with empty
    # namespace stubs (the alias root only when ``app`` itself is not loaded).
    for ns_name in ("services", "services.ai", alias_root):
        if ns_name not in _sys.modules:
            ns = _types.ModuleType(ns_name)
            ns.__path__ = []  # type: ignore[attr-defined]
            _sys.modules[ns_name] = ns


_register_module_aliases()

# ---------------------------------------------------------------------------
# Chunking parameters
# ---------------------------------------------------------------------------
_PARENT_WORDS = 1200     # parent chunk: LLM context window (≈ 1500 tokens)
_CHILD_WORDS = 150       # child chunk: retrieval unit (≈ 200 tokens)
_CHILD_OVERLAP = 30      # overlap between consecutive child chunks (words)
_MAX_WORDS = 400         # legacy: only used by chunk_pages() (no longer called by run())
_OVERLAP_WORDS = 50      # legacy: overlap used by chunk_pages()
_MIN_WORDS = 25          # paragraph threshold: shorter chunks are discarded
_MIN_WORDS_TABLE = 4     # table threshold: one data row is enough (cells are short)

# ---------------------------------------------------------------------------
# Helpers
"""Unit tests for app.services.chat_service.

Covers _qvac_dict_to_chunk() and the async answer() function.
All network calls, hybrid_search, reranker, and parent_expansion are mocked.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from app.services.chat_service import _qvac_dict_to_chunk, ChatResult, Citation
from app.schemas.evidence_pack import CitationAnchor, EvidenceChunk


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_chunk(
    chunk_id: str = "DOC1_p0000_c0000",
    text: str = "Bitcoin text.",
    score: float = 0.9,
) -> EvidenceChunk:
    """Build an EvidenceChunk fixture anchored to page 1 of document "DOC1"."""
    anchor = CitationAnchor(
        doc_id="DOC1",
        doc_name="Bitcoin Whitepaper",
        section="Intro",
        page=1,
        slide=None,
        chunk_id=chunk_id,
        chunk_type="paragraph",
    )
    return EvidenceChunk(chunk_id=chunk_id, text=text, score=score, anchor=anchor)


def _make_qvac_dict(chunk_id: str = "DOC1_p0000_c0000") -> dict:
    """Build a chunk dict shaped like one entry of the QVAC /retrieve response."""
    return dict(
        chunk_id=chunk_id,
        content="Bitcoin uses UTXO.",
        score=0.85,
        label="Bitcoin Whitepaper",
        page=3,
        slide=0,
        section="Transactions",
        doc_id="DOC1",
        parent_id="DOC1_p0000",
    )


def _mock_httpx_response(json_data: dict, status_code: int = 200):
    """Return a MagicMock standing in for a successful httpx response."""
    response = MagicMock()
    response.status_code = status_code
    response.raise_for_status = MagicMock()
    response.json.return_value = json_data
    return response


def _mock_httpx_error():
    """Return a MagicMock response whose raise_for_status() raises ConnectError."""
    import httpx

    failing = MagicMock()
    failing.raise_for_status.side_effect = httpx.ConnectError("connection refused")
    return failing


# ---------------------------------------------------------------------------
# _qvac_dict_to_chunk
# ---------------------------------------------------------------------------

@pytest.mark.unit
def test_qvac_dict_to_chunk_maps_fields():
    """Every field of the QVAC dict lands on the matching chunk/anchor attribute."""
    converted = _qvac_dict_to_chunk(_make_qvac_dict())
    anchor = converted.anchor
    assert converted.chunk_id == "DOC1_p0000_c0000"
    assert converted.text == "Bitcoin uses UTXO."
    assert converted.score == 0.85
    assert anchor.doc_name == "Bitcoin Whitepaper"
    assert anchor.page == 3
    assert anchor.doc_id == "DOC1"
    assert anchor.section == "Transactions"


@pytest.mark.unit
def test_qvac_dict_to_chunk_slide_zero_becomes_none():
    """A slide number of 0 is reported as None on the anchor."""
    payload = _make_qvac_dict()
    payload["slide"] = 0
    assert _qvac_dict_to_chunk(payload).anchor.slide is None


@pytest.mark.unit
def test_qvac_dict_to_chunk_page_zero_becomes_none():
    """A page number of 0 is reported as None on the anchor."""
    payload = _make_qvac_dict()
    payload["page"] = 0
    assert _qvac_dict_to_chunk(payload).anchor.page is None
@pytest.mark.unit
def test_qvac_dict_to_chunk_empty_section_becomes_none():
    """An empty section string is reported as None on the anchor."""
    payload = _make_qvac_dict()
    payload["section"] = ""
    assert _qvac_dict_to_chunk(payload).anchor.section is None


@pytest.mark.unit
def test_qvac_dict_to_chunk_uses_content_key():
    """When both keys carry text, "content" takes precedence over "text"."""
    payload = dict(
        chunk_id="c1",
        content="Dense content.",
        text="Should not use this.",
        score=0.5,
        label="",
        page=0,
        slide=0,
        section="",
        doc_id="",
    )
    converted = _qvac_dict_to_chunk(payload)
    assert converted.text == "Dense content."


@pytest.mark.unit
def test_qvac_dict_to_chunk_falls_back_to_text_key():
    """An empty "content" value makes the converter read "text" instead."""
    payload = dict(
        chunk_id="c1",
        content="",
        text="Fallback text.",
        score=0.5,
        label="",
        page=0,
        slide=0,
        section="",
        doc_id="",
    )
    converted = _qvac_dict_to_chunk(payload)
    assert converted.text == "Fallback text."
# ---------------------------------------------------------------------------
# answer() — happy path with hybrid search
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
@pytest.mark.unit
async def test_answer_happy_path_returns_chat_result():
    """Dense + BM25 hits go through fusion, rerank and expansion into a ChatResult."""
    chunk_dict = _make_qvac_dict()
    ev_chunk = _make_chunk()
    bm25_hits = [("DOC1_p0000_c0000", 2.5)]
    corpus = {"DOC1_p0000_c0000": {"text": "BM25 text", "doc_id": "DOC1"}}

    retrieve_resp = _mock_httpx_response({"chunks": [chunk_dict]})
    generate_resp = _mock_httpx_response({"answer": "Bitcoin is a P2P currency."})

    with patch("app.services.chat_service._client") as mock_client, \
         patch("app.services.hybrid_search.bm25_search", return_value=bm25_hits), \
         patch("app.services.hybrid_search.load_bm25_index", return_value=(None, None, corpus)), \
         patch("app.services.hybrid_search.rrf_fuse", return_value=[ev_chunk]) as mock_rrf, \
         patch("app.services.reranker.rerank", return_value=[ev_chunk]), \
         patch("app.services.parent_expansion.expand_to_parents", return_value=[ev_chunk]):

        # First POST is /retrieve, second is /generate.
        mock_client.post = AsyncMock(side_effect=[retrieve_resp, generate_resp])
        from app.services.chat_service import answer
        result = await answer("What is Bitcoin?", "COURSE1")

    # With BM25 hits present the fusion step must actually run.
    mock_rrf.assert_called_once()
    assert isinstance(result, ChatResult)
    assert result.answer == "Bitcoin is a P2P currency."
    assert result.retrieval_used is True
    assert len(result.citations) == 1


@pytest.mark.asyncio
@pytest.mark.unit
async def test_answer_dense_only_when_no_bm25():
    """Without BM25 hits the dense chunks are used directly and fusion is skipped."""
    chunk_dict = _make_qvac_dict()
    ev_chunk = _make_chunk()

    retrieve_resp = _mock_httpx_response({"chunks": [chunk_dict]})
    generate_resp = _mock_httpx_response({"answer": "Mining secures the chain."})

    with patch("app.services.chat_service._client") as mock_client, \
         patch("app.services.hybrid_search.bm25_search", return_value=[]), \
         patch("app.services.hybrid_search.rrf_fuse") as mock_rrf, \
         patch("app.services.reranker.rerank", return_value=[ev_chunk]), \
         patch("app.services.parent_expansion.expand_to_parents", return_value=[ev_chunk]):

        mock_client.post = AsyncMock(side_effect=[retrieve_resp, generate_resp])
        from app.services.chat_service import answer
        result = await answer("What is mining?", "COURSE1")

    # rrf_fuse must NOT be called when BM25 has no results
    mock_rrf.assert_not_called()
    assert result.answer == "Mining secures the chain."
@pytest.mark.asyncio
@pytest.mark.unit
async def test_answer_chroma_fallback_on_retrieve_error():
    """A transport error on /retrieve hands the question to the ChromaDB fallback."""
    import httpx

    fallback_citation = Citation(snippet="s", score=0.5, label="doc", page=1)
    chroma_result = ChatResult(
        answer="Fallback answer.",
        citations=[fallback_citation],
        retrieval_used=True,
    )
    client_patch = patch("app.services.chat_service._client")
    chroma_patch = patch(
        "app.services.chat_service._chroma_chat_result", return_value=chroma_result
    )

    with client_patch as mock_client, chroma_patch:
        mock_client.post = AsyncMock(side_effect=httpx.ConnectError("refused"))
        from app.services.chat_service import answer
        outcome = await answer("What is Bitcoin?", "COURSE1")

    assert outcome.answer == "Fallback answer."
    assert outcome.retrieval_used is True


@pytest.mark.asyncio
@pytest.mark.unit
async def test_answer_chroma_fallback_on_zero_chunks():
    """An empty chunk list from /retrieve returns the ChromaDB result when it has citations."""
    fallback_citation = Citation(snippet="s", score=0.5, label="doc", page=1)
    chroma_result = ChatResult(
        answer="Chroma answer.",
        citations=[fallback_citation],
        retrieval_used=True,
    )
    empty_retrieve = _mock_httpx_response({"chunks": []})
    client_patch = patch("app.services.chat_service._client")
    chroma_patch = patch(
        "app.services.chat_service._chroma_chat_result", return_value=chroma_result
    )

    with client_patch as mock_client, chroma_patch:
        mock_client.post = AsyncMock(return_value=empty_retrieve)
        from app.services.chat_service import answer
        outcome = await answer("What is Bitcoin?", "COURSE1")

    assert outcome.answer == "Chroma answer."
@pytest.mark.asyncio
@pytest.mark.unit
async def test_answer_generate_failure_returns_first_context_block():
    """When /generate fails, the answer is the text of the first context block."""
    import httpx

    expanded = _make_chunk(text="First context block text.")
    retrieve_ok = _mock_httpx_response({"chunks": [_make_qvac_dict()]})
    generate_failure = httpx.ConnectError("refused")

    client_patch = patch("app.services.chat_service._client")
    bm25_patch = patch("app.services.hybrid_search.bm25_search", return_value=[])
    rerank_patch = patch("app.services.reranker.rerank", return_value=[expanded])
    expand_patch = patch(
        "app.services.parent_expansion.expand_to_parents", return_value=[expanded]
    )

    with client_patch as mock_client, bm25_patch, rerank_patch, expand_patch:
        # First POST (/retrieve) succeeds, second POST (/generate) raises.
        mock_client.post = AsyncMock(side_effect=[retrieve_ok, generate_failure])
        from app.services.chat_service import answer
        outcome = await answer("What is Bitcoin?", "COURSE1")

    assert outcome.answer == "First context block text."
@pytest.mark.asyncio
@pytest.mark.unit
async def test_answer_citations_use_child_chunks():
    """Citation snippets are taken from the reranked child chunk, not its parent."""
    child_chunk = _make_chunk(text="Child text for citation.", score=0.7)
    parent_chunk = _make_chunk(text="Full parent context block, much longer.", score=0.7)

    retrieve_ok = _mock_httpx_response({"chunks": [_make_qvac_dict()]})
    generate_ok = _mock_httpx_response({"answer": "Answer."})

    client_patch = patch("app.services.chat_service._client")
    bm25_patch = patch("app.services.hybrid_search.bm25_search", return_value=[])
    rerank_patch = patch("app.services.reranker.rerank", return_value=[child_chunk])
    expand_patch = patch(
        "app.services.parent_expansion.expand_to_parents", return_value=[parent_chunk]
    )

    with client_patch as mock_client, bm25_patch, rerank_patch, expand_patch:
        mock_client.post = AsyncMock(side_effect=[retrieve_ok, generate_ok])
        from app.services.chat_service import answer
        outcome = await answer("What is Bitcoin?", "COURSE1")

    # Citations come from reranked (child), not context_chunks (parent)
    assert outcome.citations[0].snippet == "Child text for citation."[:200]
app.workers.pipeline # noqa: F401 — sets up sys.modules alias + sys.path from module_3_micro_chunker import Chunker # noqa: E402 diff --git a/services/ai/tests/unit/test_hybrid_search.py b/services/ai/tests/unit/test_hybrid_search.py new file mode 100644 index 0000000..2d456df --- /dev/null +++ b/services/ai/tests/unit/test_hybrid_search.py @@ -0,0 +1,227 @@ +"""Unit tests for app.services.hybrid_search — BM25, RRF fusion, and index loading. + +Tests that require rank_bm25 are skipped automatically when the library is absent. +""" +import json +import pickle +from pathlib import Path +from unittest.mock import patch + +import pytest + +from app.schemas.evidence_pack import CitationAnchor, EvidenceChunk +from app.services.hybrid_search import bm25_search, load_bm25_index, rrf_fuse + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_evidence_chunk(chunk_id: str, score: float = 0.9) -> EvidenceChunk: + return EvidenceChunk( + chunk_id=chunk_id, + text=f"Text for {chunk_id}.", + score=score, + anchor=CitationAnchor( + doc_id="DOC1", + doc_name="Bitcoin Whitepaper", + section="Intro", + page=1, + slide=None, + chunk_id=chunk_id, + chunk_type="paragraph", + ), + ) + + +def _write_bm25_index(tmp_path: Path, ids: list, tokenized: list) -> None: + """Build and persist a BM25 index + corpus to tmp_path.""" + pytest.importorskip("rank_bm25") + from rank_bm25 import BM25Okapi + + bm25 = BM25Okapi(tokenized) + corpus = { + cid: {"text": " ".join(toks), "doc_id": "DOC1", "label": f"p. 
{i+1}", + "page": i + 1, "section": "Intro"} + for i, (cid, toks) in enumerate(zip(ids, tokenized)) + } + with (tmp_path / "COURSE1_bm25.pkl").open("wb") as f: + pickle.dump({"bm25": bm25, "ids": ids}, f) + with (tmp_path / "COURSE1_corpus.json").open("w") as f: + json.dump(corpus, f) + + +# --------------------------------------------------------------------------- +# load_bm25_index +# --------------------------------------------------------------------------- + +@pytest.mark.unit +def test_load_bm25_index_returns_none_when_files_missing(tmp_path): + with patch("app.services.hybrid_search._QVAC_INGEST_DIR", tmp_path): + result = load_bm25_index("NO_SUCH_COURSE") + assert result is None + + +@pytest.mark.unit +def test_load_bm25_index_returns_tuple_when_present(tmp_path): + pytest.importorskip("rank_bm25") + ids = ["c1", "c2"] + tokenized = [["bitcoin", "utxo"], ["proof", "work"]] + _write_bm25_index(tmp_path, ids, tokenized) + + with patch("app.services.hybrid_search._QVAC_INGEST_DIR", tmp_path): + result = load_bm25_index("COURSE1") + + assert result is not None + bm25_obj, returned_ids, corpus = result + assert returned_ids == ids + assert isinstance(corpus, dict) + assert "c1" in corpus + + +@pytest.mark.unit +def test_load_bm25_index_returns_none_on_corrupt_pickle(tmp_path): + (tmp_path / "COURSE1_bm25.pkl").write_bytes(b"not a valid pickle") + (tmp_path / "COURSE1_corpus.json").write_text("{}") + with patch("app.services.hybrid_search._QVAC_INGEST_DIR", tmp_path): + result = load_bm25_index("COURSE1") + assert result is None + + +# --------------------------------------------------------------------------- +# bm25_search +# --------------------------------------------------------------------------- + +@pytest.mark.unit +def test_bm25_search_returns_empty_when_index_missing(tmp_path): + with patch("app.services.hybrid_search._QVAC_INGEST_DIR", tmp_path): + result = bm25_search("bitcoin", "NO_COURSE", top_k=5) + assert result == [] + + +@pytest.mark.unit +def 
test_bm25_search_returns_ranked_tuples(tmp_path): + pytest.importorskip("rank_bm25") + ids = ["chunk_bitcoin", "chunk_mining"] + tokenized = [["bitcoin", "utxo", "transaction"], ["proof", "work", "mining"]] + _write_bm25_index(tmp_path, ids, tokenized) + + with patch("app.services.hybrid_search._QVAC_INGEST_DIR", tmp_path): + results = bm25_search("bitcoin utxo", "COURSE1", top_k=5) + + assert len(results) > 0 + assert all(isinstance(r, tuple) and len(r) == 2 for r in results) + # Most relevant chunk for "bitcoin utxo" should rank first + assert results[0][0] == "chunk_bitcoin" + + +@pytest.mark.unit +def test_bm25_search_excludes_zero_scores(tmp_path): + pytest.importorskip("rank_bm25") + ids = ["relevant", "irrelevant"] + tokenized = [["bitcoin", "utxo"], ["astronomy", "stars"]] + _write_bm25_index(tmp_path, ids, tokenized) + + with patch("app.services.hybrid_search._QVAC_INGEST_DIR", tmp_path): + results = bm25_search("bitcoin", "COURSE1", top_k=10) + + chunk_ids = [r[0] for r in results] + assert "irrelevant" not in chunk_ids + assert "relevant" in chunk_ids + + +@pytest.mark.unit +def test_bm25_search_respects_top_k(tmp_path): + pytest.importorskip("rank_bm25") + ids = [f"c{i}" for i in range(10)] + tokenized = [["bitcoin", f"token{i}"] for i in range(10)] + _write_bm25_index(tmp_path, ids, tokenized) + + with patch("app.services.hybrid_search._QVAC_INGEST_DIR", tmp_path): + results = bm25_search("bitcoin", "COURSE1", top_k=3) + + assert len(results) <= 3 + + +# --------------------------------------------------------------------------- +# rrf_fuse +# --------------------------------------------------------------------------- + +@pytest.mark.unit +def test_rrf_fuse_empty_inputs_return_empty(): + result = rrf_fuse([], [], {}, top_k=5) + assert result == [] + + +@pytest.mark.unit +def test_rrf_fuse_dense_only_returns_all(): + dense = [_make_evidence_chunk("c1"), _make_evidence_chunk("c2")] + result = rrf_fuse(dense, [], {}, top_k=10) + ids = [c.chunk_id for c in 
result] + assert "c1" in ids + assert "c2" in ids + + +@pytest.mark.unit +def test_rrf_fuse_shared_id_ranks_first(): + dense = [_make_evidence_chunk("shared", 0.9), _make_evidence_chunk("dense_only", 0.5)] + bm25_hits = [("shared", 2.5), ("bm25_only", 1.0)] + corpus = { + "bm25_only": {"text": "BM25 text", "doc_id": "D", "label": "p.1", + "page": 1, "section": ""} + } + result = rrf_fuse(dense, bm25_hits, corpus, top_k=10) + assert result[0].chunk_id == "shared" + + +@pytest.mark.unit +def test_rrf_fuse_respects_top_k(): + dense = [_make_evidence_chunk(f"d{i}") for i in range(10)] + bm25_hits = [(f"b{i}", float(10 - i)) for i in range(10)] + corpus = { + f"b{i}": {"text": f"BM25 text {i}", "doc_id": "D", "label": "p.1", + "page": 1, "section": ""} + for i in range(10) + } + result = rrf_fuse(dense, bm25_hits, corpus, top_k=5) + assert len(result) == 5 + + +@pytest.mark.unit +def test_rrf_fuse_bm25_only_chunks_reconstructed_from_corpus(): + # Only BM25 hit, not in dense_chunks + bm25_hits = [("bm25_only", 3.0)] + corpus = { + "bm25_only": { + "text": "BM25-only content", + "doc_id": "DOC1", + "label": "p. 
5", + "page": 5, + "section": "Mining", + } + } + result = rrf_fuse([], bm25_hits, corpus, top_k=5) + assert len(result) == 1 + assert result[0].chunk_id == "bm25_only" + assert result[0].text == "BM25-only content" + assert result[0].anchor.page == 5 + + +@pytest.mark.unit +def test_rrf_fuse_scores_updated_to_rrf_value(): + dense = [_make_evidence_chunk("c1", score=0.99)] + result = rrf_fuse(dense, [], {}, top_k=5) + # RRF score is much smaller than cosine similarity + assert result[0].score < 0.1 + + +@pytest.mark.unit +def test_rrf_fuse_bm25_rank_order_preserved(): + bm25_hits = [("b1", 5.0), ("b2", 3.0)] + corpus = { + "b1": {"text": "B1", "doc_id": "D", "label": "p.1", "page": 1, "section": ""}, + "b2": {"text": "B2", "doc_id": "D", "label": "p.2", "page": 2, "section": ""}, + } + result = rrf_fuse([], bm25_hits, corpus, top_k=10) + ids = [c.chunk_id for c in result] + assert ids.index("b1") < ids.index("b2") diff --git a/services/ai/tests/unit/test_ingester_parser.py b/services/ai/tests/unit/test_ingester_parser.py index f0aad7a..e3fdaa8 100644 --- a/services/ai/tests/unit/test_ingester_parser.py +++ b/services/ai/tests/unit/test_ingester_parser.py @@ -15,6 +15,9 @@ import pytest +pytest.importorskip("module_1_ingestor", reason="module_1_ingestor not installed") +pytest.importorskip("module_2_parser", reason="module_2_parser not installed") + import app.workers.pipeline # noqa: F401 — triggers alias + sys.path setup from module_1_ingestor import RamSafeIngestor # noqa: E402 from module_2_parser import StructuralParser # noqa: E402