diff --git a/context_use/core.py b/context_use/core.py index a614bb8..50a45fe 100644 --- a/context_use/core.py +++ b/context_use/core.py @@ -28,7 +28,7 @@ get_memory_config, get_memory_interaction_types, ) -from context_use.store.base import MemorySearchResult +from context_use.store.base import MemorySearchResult, ThreadSearchResult from context_use.types import PipelineResult, TaskBreakdown if TYPE_CHECKING: @@ -442,6 +442,21 @@ async def insert_threads( """Insert thread rows into the store, deduplicating on ``unique_key``.""" return await self._store.insert_threads(rows, task_id) + async def search_threads( + self, + query: str, + *, + top_k: int = 10, + interaction_types: list[str] | None = None, + ) -> list[ThreadSearchResult]: + """Search threads by semantic similarity.""" + query_embedding = await self._llm_client.embed_query(query) + return await self._store.search_threads( + query_embedding=query_embedding, + top_k=top_k, + interaction_types=interaction_types, + ) + # ── Private helpers ────────────────────────────────────────────── def _batch_context(self) -> BatchContext: diff --git a/context_use/store/base.py b/context_use/store/base.py index b823661..b161692 100644 --- a/context_use/store/base.py +++ b/context_use/store/base.py @@ -37,6 +37,17 @@ class MemorySearchResult: similarity: float | None +@dataclass(frozen=True) +class ThreadSearchResult: + """A thread search hit with similarity score.""" + + id: str + interaction_type: str + content: str + asat: datetime + similarity: float + + class Store(ABC): """Abstract store for all context_use domain entities. @@ -265,6 +276,21 @@ async def upsert_thread_embedding( """Insert or replace the embedding vector for a thread.""" ... + @abstractmethod + async def search_threads( + self, + *, + query_embedding: list[float], + top_k: int = 10, + interaction_types: list[str] | None = None, + ) -> list[ThreadSearchResult]: + """Search threads by semantic similarity. + + If *interaction_types* is given, only threads whose + ``interaction_type`` is in that list are returned. + """ + ... + # ── Memory Facets ──────────────────────────────────────────────── @abstractmethod diff --git a/context_use/store/sqlite/store.py b/context_use/store/sqlite/store.py index f8766f8..773d9ec 100644 --- a/context_use/store/sqlite/store.py +++ b/context_use/store/sqlite/store.py @@ -22,7 +22,12 @@ Thread, ) from context_use.models.utils import generate_uuidv4 -from context_use.store.base import MemorySearchResult, SortOrder, Store +from context_use.store.base import ( + MemorySearchResult, + SortOrder, + Store, + ThreadSearchResult, +) from context_use.store.sqlite.schema import ( ArchiveRow, BatchRow, @@ -665,6 +670,52 @@ async def upsert_thread_embedding( ) await self._commit_unless_atomic() + async def search_threads( + self, + *, + query_embedding: list[float], + top_k: int = 10, + interaction_types: list[str] | None = None, + ) -> list[ThreadSearchResult]: + db = await self._conn() + vec_rows = await db.execute_fetchall( + "SELECT thread_id, distance FROM vec_threads " + "WHERE embedding MATCH ? AND k = ?", + (VecThreadRow.serialize(query_embedding), top_k * 4), + ) + if not vec_rows: + return [] + + candidate_ids = [r[0] for r in vec_rows] + distances: dict[str, float] = {r[0]: r[1] for r in vec_rows} + + ph = ",".join("?" for _ in candidate_ids) + sql = ( + "SELECT id, interaction_type, content, asat " + f"FROM threads WHERE id IN ({ph})" + ) + params: list = list(candidate_ids) + if interaction_types is not None: + type_ph = ",".join("?" for _ in interaction_types) + sql += f" AND interaction_type IN ({type_ph})" + params.extend(interaction_types) + + thread_rows = await db.execute_fetchall(sql, params) + + results = [ + ThreadSearchResult( + id=r["id"], + interaction_type=r["interaction_type"], + content=r["content"] or "", + asat=parse_dt(r["asat"]), + similarity=1.0 - distances[r["id"]], + ) + for r in thread_rows + if r["id"] in distances + ] + results.sort(key=lambda x: x.similarity, reverse=True) + return results[:top_k] + async def create_memory_facet(self, facet: MemoryFacet) -> MemoryFacet: db = await self._conn() await db.execute( diff --git a/tests/unit/store/test_sqlite.py b/tests/unit/store/test_sqlite.py index d423345..6f4ce7b 100644 --- a/tests/unit/store/test_sqlite.py +++ b/tests/unit/store/test_sqlite.py @@ -515,6 +515,121 @@ async def test_upsert_thread_embedding_replaces_existing(store: SqliteStore) -> assert len(result) == 1 +async def _insert_thread_with_embedding( + store: SqliteStore, + *, + unique_key: str, + interaction_type: str = "test_type", + content: str | None = None, + embedding: list[float], +) -> str: + """Helper: insert a thread and its embedding, return the thread ID.""" + archive = Archive(provider="test") + await store.create_archive(archive) + task = EtlTask( + archive_id=archive.id, + provider="test", + interaction_type=interaction_type, + source_uris=["test.json"], + ) + await store.create_task(task) + rows = [ + ThreadRow( + unique_key=unique_key, + provider="test", + interaction_type=interaction_type, + preview="p", + payload={ + "type": "Create", + "fibre_kind": "Create", + "object": {"type": "Note"}, + }, + version="1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + ) + ] + ids = await store.insert_threads(rows, task.id) + thread_id = ids[0] + if content is not None: + await store.update_thread_content(thread_id, content) + await store.upsert_thread_embedding(thread_id, embedding) + return thread_id + + +async def test_search_threads_by_embedding(store: SqliteStore) -> None: + emb_similar = [1.0, 0.0, 0.0, 0.0] + emb_different = [0.0, 0.0, 0.0, 1.0] + + id_similar = await _insert_thread_with_embedding( + store, + unique_key="uk-similar", + content="similar content", + embedding=emb_similar, + ) + await _insert_thread_with_embedding( + store, + unique_key="uk-different", + content="different content", + embedding=emb_different, + ) + + query_emb = [0.9, 0.1, 0.0, 0.0] + results = await store.search_threads(query_embedding=query_emb, top_k=1) + + assert len(results) == 1 + assert results[0].id == id_similar + assert results[0].content == "similar content" + assert results[0].similarity > 0.0 + + +async def test_search_threads_filters_by_interaction_type(store: SqliteStore) -> None: + emb = [1.0, 0.0, 0.0, 0.0] + + await _insert_thread_with_embedding( + store, + unique_key="uk-type-a", + interaction_type="type_a", + content="content a", + embedding=emb, + ) + id_b = await _insert_thread_with_embedding( + store, + unique_key="uk-type-b", + interaction_type="type_b", + content="content b", + embedding=emb, + ) + + results = await store.search_threads( + query_embedding=emb, + top_k=10, + interaction_types=["type_b"], + ) + + assert len(results) == 1 + assert results[0].id == id_b + + +async def test_search_threads_returns_empty_when_no_embeddings( + store: SqliteStore, +) -> None: + results = await store.search_threads(query_embedding=[1.0, 0.0, 0.0, 0.0], top_k=5) + assert results == [] + + +async def test_search_threads_respects_top_k(store: SqliteStore) -> None: + for i in range(5): + await _insert_thread_with_embedding( + store, + unique_key=f"uk-topk-{i}", + content=f"content {i}", + embedding=[1.0, 0.0, 0.0, 0.0], + ) + + results = await store.search_threads(query_embedding=[1.0, 0.0, 0.0, 0.0], top_k=3) + assert len(results) == 3 + + async def test_atomic_commits_on_success(store: SqliteStore) -> None: async with store.atomic(): archive = Archive(provider="test") diff --git a/tests/unit/thread_embedding/test_search.py b/tests/unit/thread_embedding/test_search.py new file mode 100644 index 0000000..1b0d695 --- /dev/null +++ b/tests/unit/thread_embedding/test_search.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from context_use.core import ContextUse +from context_use.store.base import ThreadSearchResult + + +def _make_ctx(*, search_results: list[ThreadSearchResult]) -> ContextUse: + store = AsyncMock() + store.search_threads = AsyncMock(return_value=search_results) + + llm_client = MagicMock() + llm_client.embed_query = AsyncMock(return_value=[1.0, 0.0, 0.0]) + + ctx = object.__new__(ContextUse) + ctx._store = store + ctx._llm_client = llm_client + ctx._storage = MagicMock() + return ctx + + +class TestSearchThreads: + @pytest.mark.asyncio + async def test_embeds_query_and_delegates_to_store(self) -> None: + from datetime import UTC, datetime + + hit = ThreadSearchResult( + id="t1", + interaction_type="chatgpt_conversations", + content="hello world", + asat=datetime(2025, 1, 1, tzinfo=UTC), + similarity=0.95, + ) + ctx = _make_ctx(search_results=[hit]) + + results = await ctx.search_threads("hello") + + assert len(results) == 1 + assert results[0].id == "t1" + assert results[0].similarity == 0.95 + + ctx._llm_client.embed_query.assert_awaited_once_with("hello") # type: ignore[union-attr] + ctx._store.search_threads.assert_awaited_once() # type: ignore[union-attr] + call_kwargs = ctx._store.search_threads.call_args.kwargs # type: ignore[union-attr] + assert call_kwargs["query_embedding"] == [1.0, 0.0, 0.0] + assert call_kwargs["top_k"] == 10 + + @pytest.mark.asyncio + async def test_forwards_top_k_and_interaction_types(self) -> None: + ctx = _make_ctx(search_results=[]) + + await ctx.search_threads( + "query", + top_k=5, + interaction_types=["instagram_posts"], + ) + + call_kwargs = ctx._store.search_threads.call_args.kwargs # type: ignore[union-attr] + assert call_kwargs["top_k"] == 5 + assert call_kwargs["interaction_types"] == ["instagram_posts"] + + @pytest.mark.asyncio + async def test_returns_empty_when_no_results(self) -> None: + ctx = _make_ctx(search_results=[]) + results = await ctx.search_threads("nothing") + assert results == []