Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions backend/models/schemas.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
"""Pydantic v2 schemas for LocalMind API."""

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from typing import Optional, List
from datetime import datetime
from enum import Enum


class SourceChunk(BaseModel):
    """A single retrieved document chunk attached to an assistant message.

    Attributes:
        source: Original filename (e.g. 'report.pdf').
        chunk: Zero-based chunk index within the document.
        preview: Excerpt of the retrieved chunk text for inline preview.
            The length is not enforced by this model.
    """

    # Required: the only field without a default.
    source: str
    # Defaults to the first chunk when the index is not supplied.
    chunk: int = 0
    # Defaults to an empty string when no excerpt is supplied.
    preview: str = ""


class MessageRole(str, Enum):
user = "user"
assistant = "assistant"
Expand All @@ -16,7 +29,28 @@ class ChatMessage(BaseModel):
role: MessageRole
content: str
timestamp: Optional[datetime] = None
sources: List[str] = []
sources: List[SourceChunk] = []

@field_validator("sources", mode="before")
@classmethod
def normalize_sources(cls, v: list) -> list:
    """Upgrade legacy filename strings in ``sources`` to SourceChunk objects.

    Older sessions persisted sources as a flat JSON array of filenames
    (``["report.pdf", "notes.txt"]``); newer ones persist structured
    dicts. Running in ``mode="before"``, this hook sees the raw input and
    rewrites only the string entries, so both shapes validate as
    ``List[SourceChunk]`` without a database migration.

    Args:
        v: Raw value supplied for the ``sources`` field.

    Returns:
        A new list in which every ``str`` entry is replaced by
        ``SourceChunk(source=<that string>)`` and every other entry is
        passed through unchanged. Non-list input is returned as-is and
        left for the regular field validation to accept or reject.
    """
    if isinstance(v, list):
        return [
            # A bare filename carries no chunk index or preview, so the
            # SourceChunk defaults are used for those fields.
            SourceChunk(source=entry) if isinstance(entry, str) else entry
            for entry in v
        ]
    return v


class ChatRequest(BaseModel):
Expand All @@ -32,7 +66,7 @@ class ChatResponse(BaseModel):
reply: str
session_id: str
model: str
sources: List[str] = []
sources: List[SourceChunk] = []
tokens_used: Optional[int] = None


Expand Down
43 changes: 43 additions & 0 deletions backend/services/citation_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Citation utilities — pure Python helpers with no external dependencies.

Kept separate from rag_service so they can be imported and unit-tested
without triggering the chromadb / sentence-transformers import chain.
"""

from __future__ import annotations

# Maximum number of characters of chunk text kept in a preview, not counting
# the "..." marker appended when the text is cut.
PREVIEW_MAX_CHARS = 300


def build_sources(docs: list[str], metas: list[dict]) -> list[dict]:
    """Build a structured, de-duplicated source list from ChromaDB result rows.

    Returns one entry per unique (filename, chunk-index) pair, in order of
    first appearance. When a pair occurs more than once, the first row wins.
    Each entry carries a short preview of the retrieved text, suitable for
    inline citation display in the frontend.

    Args:
        docs: Retrieved document chunk texts (parallel with *metas*). A
            ``None`` entry is treated as an empty string.
        metas: Metadata dicts, each expected to have at least ``source``
            (filename) and ``chunk`` (zero-based index) keys. Missing keys
            fall back to ``"unknown"`` and ``0``; a ``None`` entry (e.g. a
            row stored without metadata) is treated as an empty dict.

    Returns:
        List of dicts with keys ``source`` (str), ``chunk`` (int) and
        ``preview`` (str). The preview is the full text when it has at most
        PREVIEW_MAX_CHARS characters; otherwise it is the first
        PREVIEW_MAX_CHARS characters followed by ``"..."``.

    Note:
        If *docs* and *metas* differ in length, the extra trailing items of
        the longer list are ignored (``zip`` semantics).
    """
    unique: dict[tuple[str, int], dict] = {}
    for doc, meta in zip(docs, metas):
        # Tolerate missing rows instead of raising AttributeError/TypeError.
        meta = meta or {}
        text = doc or ""

        source = meta.get("source", "unknown")
        chunk = meta.get("chunk", 0)
        key = (source, chunk)
        if key in unique:
            continue

        if len(text) > PREVIEW_MAX_CHARS:
            preview = text[:PREVIEW_MAX_CHARS] + "..."
        else:
            preview = text

        unique[key] = {"source": source, "chunk": chunk, "preview": preview}
    return list(unique.values())
9 changes: 7 additions & 2 deletions backend/services/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
)
from sentence_transformers import SentenceTransformer

from services.citation_utils import build_sources

logger = logging.getLogger(__name__)

CHROMA_PATH = os.getenv("CHROMADB_DIR", "./data/chromadb")
Expand Down Expand Up @@ -72,7 +74,7 @@ def index_document(file_path: str, session_id: str) -> int:
return len(chunks)


def retrieve_context(query: str, session_id: str, top_k: int = 4) -> tuple[str, list[str]]:
def retrieve_context(query: str, session_id: str, top_k: int = 4) -> tuple[str, list[dict]]:
col = _collection(session_id)
if col.count() == 0:
return "", []
Expand All @@ -88,7 +90,10 @@ def retrieve_context(query: str, session_id: str, top_k: int = 4) -> tuple[str,
metas = results["metadatas"][0] if results["metadatas"] else []

context = "\n\n---\n\n".join(docs)
sources = list({m.get("source", "unknown") for m in metas})

# Build structured source list: one entry per unique (filename, chunk) pair,
# preserving a short preview of the retrieved text for inline citation display.
sources = build_sources(docs, metas)
return context, sources


Expand Down
235 changes: 235 additions & 0 deletions backend/tests/test_citations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""
Tests for inline citation previews.

Covers:
- build_sources() returns structured List[dict] with source/chunk/preview
- Preview is truncated to 300 chars + "..."
- Duplicate (source, chunk) pairs are collapsed to one entry
- ChatMessage.sources accepts both legacy List[str] and new List[dict] (backward compat)
- Chat endpoint returns SourceChunk-shaped objects in its JSON response
"""

import tempfile
from unittest.mock import AsyncMock, patch

from fastapi.testclient import TestClient

import services.db_service as db
from app import app
from models.schemas import ChatMessage, MessageRole, SourceChunk

# ─── Shared test client ──────────────────────────────────────────
# tempfile.mktemp() is deprecated: it only returns a name, so another process
# can create the file before we do. NamedTemporaryFile(delete=False) creates
# the file atomically and leaves it on disk after the handle is closed.
# NOTE(review): assumes db.init_db() accepts an existing empty file at
# DB_PATH (SQLite treats a zero-byte file as an empty database) — confirm.
with tempfile.NamedTemporaryFile(suffix="_citations.db", delete=False) as _tmp_file:
    _tmp = _tmp_file.name
db.DB_PATH = _tmp
db.init_db()

# Module-level client reused by every endpoint test in this file.
client = TestClient(app)


# ─── build_sources() pure helper ────────────────────────────────
# Import only the pure helper — no chromadb / sentence_transformers needed.
from services.citation_utils import build_sources # noqa: E402


class TestBuildSources:
    """Unit-test the pure build_sources() helper in complete isolation."""

    def test_returns_list_of_dicts(self):
        """The helper returns a list whose entries are plain dicts."""
        docs = ["Hello world chunk text."]
        metas = [{"source": "file.pdf", "chunk": 0}]
        sources = build_sources(docs, metas)
        assert isinstance(sources, list)
        assert isinstance(sources[0], dict)

    def test_source_dict_has_required_keys(self):
        """Each entry exposes source, chunk and preview."""
        docs = ["Some retrieved text."]
        metas = [{"source": "notes.txt", "chunk": 3}]
        s = build_sources(docs, metas)[0]
        assert s["source"] == "notes.txt"
        assert s["chunk"] == 3
        assert "preview" in s

    def test_preview_includes_chunk_text(self):
        """The preview is taken from the retrieved chunk text."""
        docs = ["The capital of France is Paris."]
        metas = [{"source": "geo.pdf", "chunk": 1}]
        s = build_sources(docs, metas)[0]
        assert "Paris" in s["preview"]

    def test_preview_truncated_at_300_chars(self):
        """Text longer than 300 chars is cut to 300 chars plus '...'."""
        long_text = "A" * 400
        docs = [long_text]
        metas = [{"source": "big.txt", "chunk": 0}]
        s = build_sources(docs, metas)[0]
        # Exact comparison: the previous `<= 304` bound would also accept an
        # off-by-one truncation (301 chars + "...").
        assert s["preview"] == "A" * 300 + "..."
        assert len(s["preview"]) == 303

    def test_short_text_not_truncated(self):
        """Text within the limit is returned unchanged, without an ellipsis."""
        short = "Short text."
        docs = [short]
        metas = [{"source": "small.txt", "chunk": 0}]
        s = build_sources(docs, metas)[0]
        assert s["preview"] == short
        assert not s["preview"].endswith("...")

    def test_duplicate_source_chunk_collapsed(self):
        """Two rows with the same (filename, chunk) → one source entry."""
        docs = ["Chunk text A.", "Chunk text A."]
        metas = [
            {"source": "dup.pdf", "chunk": 2},
            {"source": "dup.pdf", "chunk": 2},
        ]
        assert len(build_sources(docs, metas)) == 1

    def test_different_chunks_same_file_kept_separate(self):
        """Distinct chunk indices of one file produce distinct entries."""
        docs = ["First chunk.", "Second chunk."]
        metas = [
            {"source": "report.pdf", "chunk": 0},
            {"source": "report.pdf", "chunk": 1},
        ]
        assert len(build_sources(docs, metas)) == 2

    def test_multiple_files(self):
        """Chunks from different files are all reported."""
        docs = ["Alpha.", "Beta."]
        metas = [
            {"source": "a.pdf", "chunk": 0},
            {"source": "b.pdf", "chunk": 0},
        ]
        sources = build_sources(docs, metas)
        names = {s["source"] for s in sources}
        assert names == {"a.pdf", "b.pdf"}

    def test_empty_inputs(self):
        """No rows in → empty list out."""
        assert build_sources([], []) == []

    def test_missing_metadata_keys_use_defaults(self):
        """Absent metadata keys fall back to 'unknown' and chunk 0."""
        docs = ["Some text."]
        metas = [{}]  # no "source" or "chunk" keys
        s = build_sources(docs, metas)[0]
        assert s["source"] == "unknown"
        assert s["chunk"] == 0



# ─── Backward compatibility: ChatMessage accepts both shapes ─────

class TestChatMessageBackwardCompat:
    """ChatMessage.sources must accept legacy List[str] and new List[dict]."""

    def test_legacy_string_sources_accepted(self):
        """Bare filename strings are coerced into SourceChunk objects."""
        msg = ChatMessage(
            role=MessageRole.assistant,
            content="Answer",
            sources=["report.pdf", "notes.txt"],
        )
        assert len(msg.sources) == 2
        assert all(isinstance(s, SourceChunk) for s in msg.sources)
        assert msg.sources[0].source == "report.pdf"

    def test_structured_dict_sources_accepted(self):
        """Structured dicts are validated into SourceChunk objects."""
        msg = ChatMessage(
            role=MessageRole.assistant,
            content="Answer",
            sources=[{"source": "report.pdf", "chunk": 0, "preview": "Some text"}],
        )
        assert isinstance(msg.sources[0], SourceChunk)
        assert msg.sources[0].source == "report.pdf"

    def test_empty_sources_accepted(self):
        """Omitting sources yields an empty list."""
        msg = ChatMessage(role=MessageRole.user, content="Hi")
        assert msg.sources == []

    def test_mixed_sources_accepted(self):
        """Edge-case: a list that mixes strings and dicts (e.g. partial migration)."""
        msg = ChatMessage(
            role=MessageRole.assistant,
            content="Answer",
            sources=["legacy.pdf", {"source": "new.txt", "chunk": 0, "preview": "text"}],
        )
        assert len(msg.sources) == 2
        # Checking only the length would pass even if one shape were left
        # uncoerced, so verify both entries and their order explicitly.
        assert all(isinstance(s, SourceChunk) for s in msg.sources)
        assert [s.source for s in msg.sources] == ["legacy.pdf", "new.txt"]
        assert msg.sources[0].preview == ""
        assert msg.sources[1].preview == "text"


# ─── SourceChunk schema ──────────────────────────────────────────

class TestSourceChunkSchema:
    """Checks SourceChunk defaults, explicit construction and dumping."""

    def test_defaults(self):
        """Only ``source`` is required; the other fields use their defaults."""
        minimal = SourceChunk(source="file.pdf")
        assert minimal.chunk == 0
        assert minimal.preview == ""

    def test_full_construction(self):
        """Explicitly supplied values are stored on the matching fields."""
        built = SourceChunk(source="file.pdf", chunk=3, preview="Some extracted text.")
        assert built.source == "file.pdf"
        assert built.chunk == 3
        assert built.preview == "Some extracted text."

    def test_serialization(self):
        """model_dump() returns exactly the three fields as a plain dict."""
        dumped = SourceChunk(source="doc.pdf", chunk=1, preview="Preview text.").model_dump()
        expected = {"source": "doc.pdf", "chunk": 1, "preview": "Preview text."}
        assert dumped == expected


# ─── Chat endpoint returns SourceChunk-shaped sources ────────────

@patch("routes.chat.ollama_service.is_ollama_running", new_callable=AsyncMock, return_value=True)
@patch("routes.chat.ollama_service.chat", new_callable=AsyncMock, return_value="Here is the answer.")
@patch(
    "routes.chat.rag_service.retrieve_context",
    return_value=(
        "context text",
        [{"source": "doc.pdf", "chunk": 0, "preview": "Relevant excerpt from doc."}],
    ),
)
def test_chat_endpoint_returns_source_chunks(m_rag, m_chat, m_ollama):
    """POST /api/chat/ returns the retrieved sources as structured objects.

    The Ollama health check, the Ollama chat call and the RAG retrieval are
    all patched, so only the route's handling of the sources is exercised.
    Patches are applied bottom-up, hence the argument order.
    """
    session = client.post("/api/sessions/", json={"title": "Citation Test"})
    payload = {
        "message": "What does the doc say?",
        "session_id": session.json()["id"],
        "model": "llama3",
        "use_documents": True,
    }

    response = client.post("/api/chat/", json=payload)

    assert response.status_code == 200
    returned = response.json()["sources"]
    assert len(returned) == 1
    only = returned[0]
    assert only["source"] == "doc.pdf"
    assert only["chunk"] == 0
    assert "Relevant excerpt" in only["preview"]


@patch("routes.chat.ollama_service.is_ollama_running", new_callable=AsyncMock, return_value=True)
@patch("routes.chat.ollama_service.chat", new_callable=AsyncMock, return_value="No docs needed.")
@patch("routes.chat.rag_service.retrieve_context", return_value=("", []))
def test_chat_endpoint_no_documents_empty_sources(m_rag, m_chat, m_ollama):
    """A chat request with ``use_documents`` off returns an empty sources list."""
    session = client.post("/api/sessions/", json={"title": "No Doc Test"})
    payload = {
        "message": "Hello",
        "session_id": session.json()["id"],
        "model": "llama3",
        "use_documents": False,
    }

    response = client.post("/api/chat/", json=payload)

    assert response.status_code == 200
    assert response.json()["sources"] == []


# ─── Round-trip: sources saved & loaded from SQLite ──────────────

def test_sources_roundtrip_structured():
    """Structured source dicts survive JSON serialization through db_service."""
    session_id = client.post("/api/sessions/", json={"title": "RT Test"}).json()["id"]
    stored = [{"source": "report.pdf", "chunk": 2, "preview": "Some text here."}]

    db.save_message(session_id, "assistant", "An answer.", stored)

    # The message just saved is the last one returned for the session.
    first_source = db.get_messages_full(session_id)[-1]["sources"][0]
    assert first_source["source"] == "report.pdf"
    assert first_source["preview"] == "Some text here."


def test_sources_roundtrip_legacy_strings():
    """Legacy string sources survive JSON serialization through db_service."""
    session_id = client.post("/api/sessions/", json={"title": "Legacy RT Test"}).json()["id"]
    legacy = ["legacy.pdf", "old_notes.txt"]

    db.save_message(session_id, "assistant", "An answer.", legacy)

    # The storage layer hands the plain filename list back unchanged.
    latest = db.get_messages_full(session_id)[-1]
    assert latest["sources"] == ["legacy.pdf", "old_notes.txt"]
Loading
Loading