From 3aad98f5d80accaa36ef1ab1ccce0f4d46fe85ab Mon Sep 17 00:00:00 2001 From: Alex Kulikov Date: Thu, 16 Apr 2026 14:41:22 +0100 Subject: [PATCH 1/2] feat: embedding pipeline --- context_use/cli/base.py | 4 + context_use/cli/commands/__init__.py | 2 + context_use/cli/commands/embed.py | 74 ++++++++ context_use/core.py | 32 ++++ context_use/models/batch.py | 1 + context_use/models/thread.py | 12 ++ context_use/store/base.py | 9 + context_use/store/sqlite/schema.py | 22 +++ context_use/store/sqlite/store.py | 15 ++ context_use/thread_embedding/__init__.py | 0 context_use/thread_embedding/embedding.py | 57 ++++++ context_use/thread_embedding/factory.py | 11 ++ context_use/thread_embedding/manager.py | 75 ++++++++ context_use/thread_embedding/states.py | 72 ++++++++ tests/unit/models/test_thread.py | 49 +++++ tests/unit/store/test_sqlite.py | 83 +++++++++ tests/unit/thread_embedding/__init__.py | 0 .../thread_embedding/test_batch_creation.py | 133 ++++++++++++++ tests/unit/thread_embedding/test_embedding.py | 99 ++++++++++ tests/unit/thread_embedding/test_manager.py | 169 ++++++++++++++++++ tests/unit/thread_embedding/test_states.py | 68 +++++++ 21 files changed, 987 insertions(+) create mode 100644 context_use/cli/commands/embed.py create mode 100644 context_use/thread_embedding/__init__.py create mode 100644 context_use/thread_embedding/embedding.py create mode 100644 context_use/thread_embedding/factory.py create mode 100644 context_use/thread_embedding/manager.py create mode 100644 context_use/thread_embedding/states.py create mode 100644 tests/unit/thread_embedding/__init__.py create mode 100644 tests/unit/thread_embedding/test_batch_creation.py create mode 100644 tests/unit/thread_embedding/test_embedding.py create mode 100644 tests/unit/thread_embedding/test_manager.py create mode 100644 tests/unit/thread_embedding/test_states.py diff --git a/context_use/cli/base.py b/context_use/cli/base.py index 610b745c..e8632165 100644 --- a/context_use/cli/base.py +++ b/context_use/cli/base.py @@ -30,6 +30,7 @@ def _batch_detail_from_state(state: State | None) -> str: MemoryEmbedCompleteState, MemoryGenerateCompleteState, ) + from context_use.thread_embedding.states import ThreadEmbedCompleteState if isinstance(state, FailedState): message = state.error_message.strip() @@ -50,6 +51,9 @@ def _batch_detail_from_state(state: State | None) -> str: if isinstance(state, DescGenerateCompleteState): return f"{state.descriptions_count} descriptions generated" + if isinstance(state, ThreadEmbedCompleteState): + return f"{state.embedded_count} threads embedded" + return "" diff --git a/context_use/cli/commands/__init__.py b/context_use/cli/commands/__init__.py index 2bba8e48..db4c4c16 100644 --- a/context_use/cli/commands/__init__.py +++ b/context_use/cli/commands/__init__.py @@ -4,6 +4,7 @@ from context_use.cli.commands.agent import AgentGroup from context_use.cli.commands.config import ConfigGroup from context_use.cli.commands.describe import DescribeCommand +from context_use.cli.commands.embed import EmbedCommand from context_use.cli.commands.ingest import IngestCommand from context_use.cli.commands.memories import MemoriesGroup from context_use.cli.commands.pipeline import PipelineCommand @@ -15,6 +16,7 @@ PipelineCommand, IngestCommand, DescribeCommand, + EmbedCommand, ResetCommand, ] diff --git a/context_use/cli/commands/embed.py b/context_use/cli/commands/embed.py new file mode 100644 index 00000000..b70c8f22 --- /dev/null +++ b/context_use/cli/commands/embed.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import argparse +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING + +from context_use.cli import output as out +from context_use.cli.base import ApiCommand, create_batch_reporter, run_batches +from context_use.cli.config import Config + +if TYPE_CHECKING: + from context_use import ContextUse + + +class EmbedCommand(ApiCommand): + name = "embed" + help = "Embed thread content for semantic search" + description = ( + "Generate vector embeddings for all unprocessed threads. " + "Asset threads require descriptions first (run 'describe' beforehand). " + "Use --last-days or --since to limit the date range." + ) + llm_mode = "batch" + + def add_arguments(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--last-days", + type=int, + default=None, + help="Only process threads from the last N days", + ) + parser.add_argument( + "--since", + type=str, + default=None, + help="Only process threads after this date (YYYY-MM-DD)", + ) + + async def run( + self, + cfg: Config, + ctx: ContextUse, + args: argparse.Namespace, + ) -> None: + since = self._resolve_since(args) + + out.header("Embedding threads") + out.info("Embeds all threads that have not been embedded yet.") + if since: + out.kv("Since", since.strftime("%Y-%m-%d")) + print() + + batches = await ctx.create_thread_embedding_batches(since=since) + + if not batches: + out.info("No threads to embed.") + return + + await run_batches( + ctx, + batches, + reporter_factory=create_batch_reporter, + ) + + out.success("Thread embeddings generated") + out.kv("Batches", len(batches)) + print() + + def _resolve_since(self, args: argparse.Namespace) -> datetime | None: + if args.since: + return datetime.fromisoformat(args.since).replace(tzinfo=UTC) + if args.last_days is not None: + return datetime.now(UTC) - timedelta(days=args.last_days) + return None diff --git a/context_use/core.py b/context_use/core.py index 6c863337..a614bb8f 100644 --- a/context_use/core.py +++ b/context_use/core.py @@ -235,6 +235,37 @@ async def create_asset_description_batches( groups = [ThreadGroup(threads=[t], group_id=t.id) for t in asset_threads] return await AssetDescriptionBatchFactory.create_batches(groups, self._store) + # ── Thread embedding batches ───────────────────────────────────── + + async def create_thread_embedding_batches( + self, + *, + task_id: str | None = None, + since: datetime | None = None, + before: datetime | None = None, + ) -> list[Batch]: + """Create batches for thread embedding. + + Every thread with embeddable content is included. Asset threads + without a description yet are silently skipped (their + ``get_embeddable_content()`` returns ``None``). + """ + from context_use.thread_embedding.factory import ThreadEmbeddingBatchFactory + + threads = await self._store.get_unprocessed_threads( + batch_category=BatchCategory.thread_embedding.value, + task_id=task_id, + since=since, + before=before, + ) + + embeddable = [t for t in threads if t.get_embeddable_content() is not None] + if not embeddable: + return [] + + groups = [ThreadGroup(threads=[t], group_id=t.id) for t in embeddable] + return await ThreadEmbeddingBatchFactory.create_batches(groups, self._store) + # ── Memory batches ──────────────────────────────────────────────── async def create_memory_batches( @@ -452,3 +483,4 @@ def _ensure_managers_registered() -> None: """Import manager modules to trigger their @register_batch_manager decorators.""" import context_use.asset_description.manager # noqa: F401 import context_use.memories.manager # noqa: F401 + import context_use.thread_embedding.manager # noqa: F401 diff --git a/context_use/models/batch.py b/context_use/models/batch.py index 29f068ba..b947e8f8 100644 --- a/context_use/models/batch.py +++ b/context_use/models/batch.py @@ -20,6 +20,7 @@ class BatchCategory(enum.StrEnum): memories = "memories" asset_description = "asset_description" + thread_embedding = "thread_embedding" @dataclass diff --git a/context_use/models/thread.py b/context_use/models/thread.py index f98d973f..15d63af5 100644 --- a/context_use/models/thread.py +++ b/context_use/models/thread.py @@ -82,6 +82,18 @@ def get_raw_content(self) -> str: """Return semantic content from the payload, ignoring any enriched content.""" return self._parsed_payload.get_content() or "" + def get_embeddable_content(self) -> str | None: + """Return text suitable for embedding, or ``None`` to skip. + + Asset threads use the enriched ``content`` (set by the describe + pipeline). If no description exists yet the thread is not ready + for embedding. Non-asset threads use the raw payload content. + """ + if self.is_asset: + return self.content + raw = self.get_raw_content() + return raw or None + def get_participant_label(self) -> str: return self._parsed_payload.get_participant_label() diff --git a/context_use/store/base.py b/context_use/store/base.py index efb45e63..b8236610 100644 --- a/context_use/store/base.py +++ b/context_use/store/base.py @@ -256,6 +256,15 @@ async def search_memories( """Search memories by semantic similarity, optionally filtered by date range.""" ... + # ── Thread Embeddings ───────────────────────────────────────────── + + @abstractmethod + async def upsert_thread_embedding( + self, thread_id: str, embedding: list[float] + ) -> None: + """Insert or replace the embedding vector for a thread.""" + ... + # ── Memory Facets ──────────────────────────────────────────────── @abstractmethod diff --git a/context_use/store/sqlite/schema.py b/context_use/store/sqlite/schema.py index 6c7acead..78adba61 100644 --- a/context_use/store/sqlite/schema.py +++ b/context_use/store/sqlite/schema.py @@ -381,6 +381,27 @@ def from_row(row: Row) -> MemoryFacet: ) +class VecThreadRow: + table = "vec_threads" + + @classmethod + def ddl(cls, embedding_dimensions: int) -> str: + return ( + "CREATE VIRTUAL TABLE IF NOT EXISTS vec_threads " + f"USING vec0(\n" + f" thread_id TEXT PRIMARY KEY,\n" + f" embedding float[{embedding_dimensions}] " + f"distance_metric=cosine\n" + f")" + ) + + @staticmethod + def serialize(embedding: list[float]) -> bytes: + from sqlite_vec import serialize_float32 + + return serialize_float32(embedding) + + class VecFacetRow: table = "vec_facets" @@ -420,5 +441,6 @@ def all_ddl_statements(embedding_dimensions: int) -> list[str]: stmts.append(model.ddl()) stmts.extend(model.indices()) stmts.append(VecMemoryRow.ddl(embedding_dimensions)) + stmts.append(VecThreadRow.ddl(embedding_dimensions)) stmts.append(VecFacetRow.ddl(embedding_dimensions)) return stmts diff --git a/context_use/store/sqlite/store.py b/context_use/store/sqlite/store.py index a64a2ecb..f8766f8b 100644 --- a/context_use/store/sqlite/store.py +++ b/context_use/store/sqlite/store.py @@ -33,6 +33,7 @@ ThreadRow, VecFacetRow, VecMemoryRow, + VecThreadRow, all_ddl_statements, now_utc_iso, parse_dt, @@ -650,6 +651,20 @@ async def search_memories( top_k=top_k, ) + async def upsert_thread_embedding( + self, thread_id: str, embedding: list[float] + ) -> None: + db = await self._conn() + await db.execute( + "DELETE FROM vec_threads WHERE thread_id = ?", + (thread_id,), + ) + await db.execute( + "INSERT INTO vec_threads (thread_id, embedding) VALUES (?, ?)", + (thread_id, VecThreadRow.serialize(embedding)), + ) + await self._commit_unless_atomic() + async def create_memory_facet(self, facet: MemoryFacet) -> MemoryFacet: db = await self._conn() await db.execute( diff --git a/context_use/thread_embedding/__init__.py b/context_use/thread_embedding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/context_use/thread_embedding/embedding.py b/context_use/thread_embedding/embedding.py new file mode 100644 index 00000000..11d9b25b --- /dev/null +++ b/context_use/thread_embedding/embedding.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from context_use.llm.base import BaseLLMClient, EmbedBatchResults, EmbedItem +from context_use.models.thread import Thread + +if TYPE_CHECKING: + from context_use.store.base import Store + +logger = logging.getLogger(__name__) + + +async def submit_thread_embeddings( + threads: list[Thread], + batch_id: str, + llm_client: BaseLLMClient, +) -> tuple[str, list[str]]: + """Submit an embedding batch for *threads*. + + Only threads with embeddable content are included. + Returns ``(job_key, embedded_thread_ids)``. + """ + items: list[EmbedItem] = [] + included_ids: list[str] = [] + for t in threads: + text = t.get_embeddable_content() + if text is None: + continue + items.append(EmbedItem(item_id=t.id, text=text)) + included_ids.append(t.id) + + if not items: + raise ValueError("No threads with embeddable content") + + logger.info("[%s] Submitting embed batch for %d threads", batch_id, len(items)) + job_key = await llm_client.embed_batch_submit(batch_id, items) + return job_key, included_ids + + +async def store_thread_embeddings( + results: EmbedBatchResults, + batch_id: str, + store: Store, +) -> int: + """Write embedding vectors into the vec_threads table. + + Returns count stored. + """ + count = 0 + for thread_id, vector in results.items(): + await store.upsert_thread_embedding(thread_id, vector) + count += 1 + + logger.info("[%s] Stored %d thread embeddings", batch_id, count) + return count diff --git a/context_use/thread_embedding/factory.py b/context_use/thread_embedding/factory.py new file mode 100644 index 00000000..2dd11fd6 --- /dev/null +++ b/context_use/thread_embedding/factory.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from typing import ClassVar + +from context_use.batch.factory import BaseBatchFactory +from context_use.models.batch import BatchCategory + + +class ThreadEmbeddingBatchFactory(BaseBatchFactory): + BATCH_CATEGORIES: ClassVar[list[BatchCategory]] = [BatchCategory.thread_embedding] + MAX_GROUPS_PER_BATCH = 100 diff --git a/context_use/thread_embedding/manager.py b/context_use/thread_embedding/manager.py new file mode 100644 index 00000000..0513dc70 --- /dev/null +++ b/context_use/thread_embedding/manager.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import logging + +from context_use.batch.manager import ( + BaseBatchManager, + BatchContext, + register_batch_manager, +) +from context_use.batch.states import CompleteState, CreatedState, SkippedState, State +from context_use.models.batch import Batch, BatchCategory +from context_use.thread_embedding.embedding import ( + store_thread_embeddings, + submit_thread_embeddings, +) +from context_use.thread_embedding.factory import ThreadEmbeddingBatchFactory +from context_use.thread_embedding.states import ( + ThreadEmbedCompleteState, + ThreadEmbedPendingState, +) + +logger = logging.getLogger(__name__) + + +@register_batch_manager(BatchCategory.thread_embedding) +class ThreadEmbeddingBatchManager(BaseBatchManager): + """Embeds thread content into vec_threads. + + State machine: + CREATED -> THREAD_EMBED_PENDING -> THREAD_EMBED_COMPLETE -> COMPLETE + """ + + def __init__(self, batch: Batch, ctx: BatchContext) -> None: + super().__init__(batch, ctx) + self.batch_factory = ThreadEmbeddingBatchFactory + + async def _transition(self, current_state: State) -> State | None: + match current_state: + case CreatedState(): + logger.info("[%s] Starting thread embedding", self.batch.id) + return await self._submit_embeddings() + + case ThreadEmbedPendingState() as state: + logger.info("[%s] Polling thread embedding status", self.batch.id) + return await self._check_embedding_results(state) + + case ThreadEmbedCompleteState(): + logger.info("[%s] Thread embedding complete", self.batch.id) + return CompleteState() + + case _: + raise ValueError( + f"Invalid state for thread_embedding batch: {current_state}" + ) + + async def _submit_embeddings(self) -> State: + groups = await self.batch_factory.get_batch_groups(self.batch, self.ctx.store) + threads = [t for g in groups for t in g.threads] + + embeddable = [t for t in threads if t.get_embeddable_content() is not None] + if not embeddable: + return SkippedState(reason="No threads with embeddable content") + + job_key, _ = await submit_thread_embeddings( + embeddable, self.batch.id, self.ctx.llm_client + ) + return ThreadEmbedPendingState(job_key=job_key) + + async def _check_embedding_results(self, state: ThreadEmbedPendingState) -> State: + results = await self.ctx.llm_client.embed_batch_get_results(state.job_key) + if results is None: + return state + + count = await store_thread_embeddings(results, self.batch.id, self.ctx.store) + return ThreadEmbedCompleteState(embedded_count=count) diff --git a/context_use/thread_embedding/states.py b/context_use/thread_embedding/states.py new file mode 100644 index 00000000..a057045b --- /dev/null +++ b/context_use/thread_embedding/states.py @@ -0,0 +1,72 @@ +# pyright: reportIncompatibleVariableOverride=false +# Literal field overrides are the standard Pydantic discriminated-union pattern. + +from __future__ import annotations + +import random +from datetime import datetime +from typing import Literal + +from pydantic import Field + +from context_use.batch.registry import register_batch_state_parser +from context_use.batch.states import ( + CompleteState, + CreatedState, + CurrentState, + FailedState, + NextState, + SkippedState, + State, + _utc_now, +) +from context_use.models.batch import BatchCategory + +EMBED_POLL_INTERVAL_SECS = 30 + + +class ThreadEmbedPendingState(CurrentState): + status: Literal["THREAD_EMBED_PENDING"] = "THREAD_EMBED_PENDING" + job_key: str + submitted_at: datetime = Field(default_factory=_utc_now) + + @property + def poll_next_countdown(self) -> int: + jitter = random.randint(-10, 10) + return max(0, EMBED_POLL_INTERVAL_SECS + jitter) + + +class ThreadEmbedCompleteState(NextState): + status: Literal["THREAD_EMBED_COMPLETE"] = "THREAD_EMBED_COMPLETE" + completed_at: datetime = Field(default_factory=_utc_now) + embedded_count: int = 0 + + +ThreadEmbeddingBatchState = ( + CreatedState + | ThreadEmbedPendingState + | ThreadEmbedCompleteState + | CompleteState + | SkippedState + | FailedState +) + +_STATE_MAP: dict[str, type[State]] = { + "CREATED": CreatedState, + "THREAD_EMBED_PENDING": ThreadEmbedPendingState, + "THREAD_EMBED_COMPLETE": ThreadEmbedCompleteState, + "COMPLETE": CompleteState, + "SKIPPED": SkippedState, + "FAILED": FailedState, +} + + +@register_batch_state_parser(BatchCategory.thread_embedding) +def parse_thread_embedding_batch_state(state_dict: dict) -> State: + status = state_dict.get("status") + if status is None: + raise ValueError("State dict missing 'status' key") + cls = _STATE_MAP.get(status) + if cls is None: + raise ValueError(f"Unknown thread_embedding batch state: {status}") + return cls.model_validate(state_dict) diff --git a/tests/unit/models/test_thread.py b/tests/unit/models/test_thread.py index d94b591e..1bb8e4e7 100644 --- a/tests/unit/models/test_thread.py +++ b/tests/unit/models/test_thread.py @@ -101,6 +101,55 @@ def test_caption_extracted_from_payload(self) -> None: assert thread.get_content() == "sunset at the beach" +class TestThreadGetEmbeddableContent: + def test_non_asset_returns_raw_content(self) -> None: + thread = Thread( + unique_key="k1", + provider="ChatGPT", + interaction_type="chatgpt_conversations", + payload=_make_send_message_payload(), + version="1.1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + ) + assert thread.get_embeddable_content() == "hello world" + + def test_non_asset_returns_none_when_empty(self) -> None: + thread = Thread( + unique_key="k2", + provider="Instagram", + interaction_type="instagram_posts", + payload=_make_create_object_payload(), + version="1.1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + ) + assert thread.get_embeddable_content() is None + + def test_asset_returns_enriched_content(self) -> None: + thread = Thread( + unique_key="k3", + provider="Instagram", + interaction_type="instagram_posts", + payload=_make_create_object_payload(), + version="1.1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + asset_uri="archive/pic.jpg", + content="A sunset over the ocean", + ) + assert thread.get_embeddable_content() == "A sunset over the ocean" + + def test_asset_returns_none_when_not_described(self) -> None: + thread = Thread( + unique_key="k4", + provider="Instagram", + interaction_type="instagram_posts", + payload=_make_create_object_payload(caption="sunset"), + version="1.1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + asset_uri="archive/pic.jpg", + ) + assert thread.get_embeddable_content() is None + + class TestThreadGetRawContent: def test_returns_payload_content_even_when_content_is_set(self) -> None: thread = Thread( diff --git a/tests/unit/store/test_sqlite.py b/tests/unit/store/test_sqlite.py index 8cef3d7d..d4233454 100644 --- a/tests/unit/store/test_sqlite.py +++ b/tests/unit/store/test_sqlite.py @@ -432,6 +432,89 @@ async def test_search_memories_by_embedding_with_date_filter( assert results[0].similarity is not None +async def test_upsert_thread_embedding(store: SqliteStore) -> None: + archive = Archive(provider="test") + await store.create_archive(archive) + task = EtlTask( + archive_id=archive.id, + provider="test", + interaction_type="test_type", + source_uris=["test.json"], + ) + await store.create_task(task) + rows = [ + ThreadRow( + unique_key="uk-embed-1", + provider="test", + interaction_type="test_type", + preview="p", + payload={ + "type": "Create", + "fibre_kind": "Create", + "object": {"type": "Note"}, + }, + version="1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + ) + ] + ids = await store.insert_threads(rows, task.id) + thread_id = ids[0] + + embedding = [1.0, 0.0, 0.0, 0.0] + await store.upsert_thread_embedding(thread_id, embedding) + + db = await store._conn() + result = list( + await db.execute_fetchall( + "SELECT thread_id FROM vec_threads WHERE thread_id = ?", + (thread_id,), + ) + ) + assert len(result) == 1 + assert result[0][0] == thread_id + + +async def test_upsert_thread_embedding_replaces_existing(store: SqliteStore) -> None: + archive = Archive(provider="test") + await store.create_archive(archive) + task = EtlTask( + archive_id=archive.id, + provider="test", + interaction_type="test_type", + source_uris=["test.json"], + ) + await store.create_task(task) + rows = [ + ThreadRow( + unique_key="uk-embed-replace", + provider="test", + interaction_type="test_type", + preview="p", + payload={ + "type": "Create", + "fibre_kind": "Create", + "object": {"type": "Note"}, + }, + version="1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + ) + ] + ids = await store.insert_threads(rows, task.id) + thread_id = ids[0] + + await store.upsert_thread_embedding(thread_id, [1.0, 0.0, 0.0, 0.0]) + await store.upsert_thread_embedding(thread_id, [0.0, 1.0, 0.0, 0.0]) + + db = await store._conn() + result = list( + await db.execute_fetchall( + "SELECT thread_id FROM vec_threads WHERE thread_id = ?", + (thread_id,), + ) + ) + assert len(result) == 1 + + async def test_atomic_commits_on_success(store: SqliteStore) -> None: async with store.atomic(): archive = Archive(provider="test") diff --git a/tests/unit/thread_embedding/__init__.py b/tests/unit/thread_embedding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/thread_embedding/test_batch_creation.py b/tests/unit/thread_embedding/test_batch_creation.py new file mode 100644 index 00000000..15b9e987 --- /dev/null +++ b/tests/unit/thread_embedding/test_batch_creation.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from context_use.core import ContextUse +from context_use.models.thread import Thread + + +def _make_thread( + *, + thread_id: str = "t1", + asset_uri: str | None = None, + content: str | None = None, + caption: str | None = "hello world", +) -> Thread: + obj: dict = {"type": "Note", "fibre_kind": "TextMessage"} + if caption is not None: + obj["content"] = caption + return Thread( + id=thread_id, + unique_key=f"uk-{thread_id}", + provider="ChatGPT", + interaction_type="chatgpt_conversations", + payload={ + "type": "Create", + "fibre_kind": "SendMessage", + "object": obj, + "target": {"type": "Application", "name": "assistant"}, + }, + version="1.1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + asset_uri=asset_uri, + content=content, + ) + + +def _make_ctx(*, threads: list[Thread]) -> ContextUse: + store = AsyncMock() + store.get_unprocessed_threads = AsyncMock(return_value=threads) + store.create_batch = AsyncMock(side_effect=lambda b, _groups: b) + + ctx = object.__new__(ContextUse) + ctx._store = store + ctx._llm_client = MagicMock() + ctx._storage = MagicMock() + return ctx + + +class TestCreateThreadEmbeddingBatches: + @pytest.mark.asyncio + async def test_creates_batches_for_embeddable_threads(self) -> None: + threads = [ + _make_thread(thread_id="t1", caption="hello"), + _make_thread(thread_id="t2", caption="world"), + ] + ctx = _make_ctx(threads=threads) + + batches = await ctx.create_thread_embedding_batches() + + assert len(batches) == 1 + + @pytest.mark.asyncio + async def test_skips_undescribed_asset_threads(self) -> None: + threads = [ + _make_thread(thread_id="text", caption="hello"), + _make_thread(thread_id="asset", asset_uri="pic.jpg", caption=None), + ] + ctx = _make_ctx(threads=threads) + + with patch( + "context_use.thread_embedding.factory.ThreadEmbeddingBatchFactory.create_batches", + new_callable=AsyncMock, + return_value=[], + ) as mock_create: + await ctx.create_thread_embedding_batches() + groups = mock_create.call_args[0][0] + + group_ids = [g.group_id for g in groups] + assert "text" in group_ids + assert "asset" not in group_ids + + @pytest.mark.asyncio + async def test_includes_described_asset_threads(self) -> None: + threads = [ + _make_thread( + thread_id="asset", + asset_uri="pic.jpg", + content="A sunset", + caption=None, + ), + ] + ctx = _make_ctx(threads=threads) + + batches = await ctx.create_thread_embedding_batches() + + assert len(batches) == 1 + + @pytest.mark.asyncio + async def test_returns_empty_when_no_embeddable(self) -> None: + threads = [ + _make_thread(thread_id="asset", asset_uri="pic.jpg", caption=None), + ] + ctx = _make_ctx(threads=threads) + + batches = await ctx.create_thread_embedding_batches() + assert batches == [] + + @pytest.mark.asyncio + async def test_returns_empty_when_no_threads(self) -> None: + ctx = _make_ctx(threads=[]) + batches = await ctx.create_thread_embedding_batches() + assert batches == [] + + @pytest.mark.asyncio + async def test_forwards_task_id_to_store(self) -> None: + ctx = _make_ctx(threads=[]) + await ctx.create_thread_embedding_batches(task_id="task-42") + + mock: AsyncMock = ctx._store.get_unprocessed_threads # type: ignore[assignment] + mock.assert_awaited_once() + assert mock.call_args.kwargs["task_id"] == "task-42" + + @pytest.mark.asyncio + async def test_forwards_since_to_store(self) -> None: + since = datetime(2025, 6, 1, tzinfo=UTC) + ctx = _make_ctx(threads=[]) + await ctx.create_thread_embedding_batches(since=since) + + mock: AsyncMock = ctx._store.get_unprocessed_threads # type: ignore[assignment] + assert mock.call_args.kwargs["since"] == since diff --git a/tests/unit/thread_embedding/test_embedding.py b/tests/unit/thread_embedding/test_embedding.py new file mode 100644 index 00000000..3540e3f7 --- /dev/null +++ b/tests/unit/thread_embedding/test_embedding.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import AsyncMock + +import pytest + +from context_use.llm.base import EmbedItem +from context_use.models.thread import Thread +from context_use.thread_embedding.embedding import ( + store_thread_embeddings, + submit_thread_embeddings, +) + + +def _make_thread( + *, + thread_id: str = "t1", + asset_uri: str | None = None, + content: str | None = None, + caption: str | None = "hello world", +) -> Thread: + obj: dict = {"type": "Note", "fibre_kind": "TextMessage"} + if caption is not None: + obj["content"] = caption + payload: dict = { + "type": "Create", + "fibre_kind": "SendMessage", + "object": obj, + "target": {"type": "Application", "name": "assistant"}, + } + return Thread( + id=thread_id, + unique_key=f"uk-{thread_id}", + provider="ChatGPT", + interaction_type="chatgpt_conversations", + payload=payload, + version="1.1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + asset_uri=asset_uri, + content=content, + ) + + +class TestSubmitThreadEmbeddings: + @pytest.mark.asyncio + async def test_submits_embeddable_threads(self) -> None: + threads = [_make_thread(thread_id="t1"), _make_thread(thread_id="t2")] + llm = AsyncMock() + llm.embed_batch_submit = AsyncMock(return_value="job-1") + + job_key, ids = await submit_thread_embeddings(threads, "batch-1", llm) + + assert job_key == "job-1" + assert ids == ["t1", "t2"] + items = llm.embed_batch_submit.call_args[0][1] + assert len(items) == 2 + assert all(isinstance(i, EmbedItem) for i in items) + + @pytest.mark.asyncio + async def test_skips_threads_without_embeddable_content(self) -> None: + threads = [ + _make_thread(thread_id="t1", caption="hello"), + _make_thread(thread_id="t2", caption=None, asset_uri="pic.jpg"), + ] + llm = AsyncMock() + llm.embed_batch_submit = AsyncMock(return_value="job-1") + + job_key, ids = await submit_thread_embeddings(threads, "batch-1", llm) + + assert ids == ["t1"] + + @pytest.mark.asyncio + async def test_raises_when_no_embeddable_content(self) -> None: + threads = [ + _make_thread(thread_id="t1", caption=None, asset_uri="pic.jpg"), + ] + llm = AsyncMock() + + with pytest.raises(ValueError, match="No threads with embeddable content"): + await submit_thread_embeddings(threads, "batch-1", llm) + + +class TestStoreThreadEmbeddings: + @pytest.mark.asyncio + async def test_stores_all_results(self) -> None: + store = AsyncMock() + results = {"t1": [1.0, 0.0], "t2": [0.0, 1.0]} + + count = await store_thread_embeddings(results, "batch-1", store) + + assert count == 2 + assert store.upsert_thread_embedding.await_count == 2 + + @pytest.mark.asyncio + async def test_returns_zero_for_empty_results(self) -> None: + store = AsyncMock() + count = await store_thread_embeddings({}, "batch-1", store) + assert count == 0 diff --git a/tests/unit/thread_embedding/test_manager.py b/tests/unit/thread_embedding/test_manager.py new file mode 100644 index 00000000..d151d0c7 --- /dev/null +++ b/tests/unit/thread_embedding/test_manager.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from context_use.batch.grouper import ThreadGroup +from context_use.batch.manager import BatchContext +from context_use.batch.states import CompleteState, CreatedState, SkippedState +from context_use.models.batch import Batch, BatchCategory +from context_use.models.thread import Thread +from context_use.thread_embedding.manager import ThreadEmbeddingBatchManager +from context_use.thread_embedding.states import ( + ThreadEmbedCompleteState, + ThreadEmbedPendingState, +) + + +def _make_batch() -> Batch: + return Batch( + batch_number=1, + category=BatchCategory.thread_embedding.value, + states=[CreatedState().model_dump(mode="json")], + ) + + +def _make_thread( + *, + thread_id: str = "t1", + caption: str = "hello world", + asset_uri: str | None = None, + content: str | None = None, +) -> Thread: + obj: dict = {"type": "Note", "fibre_kind": "TextMessage", "content": caption} + return Thread( + id=thread_id, + unique_key=f"uk-{thread_id}", + provider="ChatGPT", + interaction_type="chatgpt_conversations", + payload={ + "type": "Create", + "fibre_kind": "SendMessage", + "object": obj, + "target": {"type": "Application", "name": "assistant"}, + }, + version="1.1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + asset_uri=asset_uri, + content=content, + ) + + +def _make_ctx() -> BatchContext: + return BatchContext( + store=AsyncMock(), + llm_client=AsyncMock(), + storage=MagicMock(), + ) + + +class TestThreadEmbeddingBatchManager: + @pytest.mark.asyncio + async def test_full_state_machine(self) -> None: + batch = _make_batch() + thread = _make_thread() + ctx = _make_ctx() + ctx.llm_client.embed_batch_submit = AsyncMock(return_value="embed-job-1") + ctx.llm_client.embed_batch_get_results = AsyncMock( + return_value={thread.id: [1.0, 0.0, 0.0]} + ) + + manager = ThreadEmbeddingBatchManager(batch, ctx) + + with patch.object( + manager.batch_factory, + "get_batch_groups", + return_value=[ThreadGroup(threads=[thread], group_id=thread.id)], + ): + state = await manager._transition(CreatedState()) + assert isinstance(state, ThreadEmbedPendingState) + assert state.job_key == "embed-job-1" + + state2 = await manager._transition(state) + assert isinstance(state2, ThreadEmbedCompleteState) + assert state2.embedded_count == 1 + + state3 = await manager._transition(state2) + assert isinstance(state3, CompleteState) + + @pytest.mark.asyncio + async def test_polls_while_results_none(self) -> None: + batch = _make_batch() + ctx = _make_ctx() + ctx.llm_client.embed_batch_get_results = AsyncMock(return_value=None) + + manager = ThreadEmbeddingBatchManager(batch, ctx) + pending = ThreadEmbedPendingState(job_key="job-1") + + state = await manager._transition(pending) + assert isinstance(state, ThreadEmbedPendingState) + assert state.job_key == "job-1" + + @pytest.mark.asyncio + async def test_skip_when_no_embeddable_threads(self) -> None: + batch = _make_batch() + ctx = _make_ctx() + # Asset thread with no description -> get_embeddable_content() returns None + thread = _make_thread(asset_uri="pic.jpg") + manager = ThreadEmbeddingBatchManager(batch, ctx) + + with patch.object( + manager.batch_factory, + "get_batch_groups", + return_value=[ThreadGroup(threads=[thread], group_id=thread.id)], + ): + state = await manager._transition(CreatedState()) + assert isinstance(state, SkippedState) + + @pytest.mark.asyncio + async def test_skip_when_no_groups(self) -> None: + batch = _make_batch() + ctx = _make_ctx() + manager = ThreadEmbeddingBatchManager(batch, ctx) + + with patch.object( + manager.batch_factory, + "get_batch_groups", + return_value=[], + ): + state = await manager._transition(CreatedState()) + assert isinstance(state, SkippedState) + + @pytest.mark.asyncio + async def test_mixed_threads_skips_undescribed_assets(self) -> None: + batch = _make_batch() + ctx = _make_ctx() + ctx.llm_client.embed_batch_submit = AsyncMock(return_value="embed-job-2") + + text_thread = _make_thread(thread_id="text-1", caption="some text") + asset_thread = _make_thread(thread_id="asset-1", asset_uri="pic.jpg") + described_asset = _make_thread( + thread_id="asset-2", + asset_uri="pic2.jpg", + content="A beautiful sunset", + ) + + groups = [ + ThreadGroup(threads=[text_thread], group_id=text_thread.id), + ThreadGroup(threads=[asset_thread], group_id=asset_thread.id), + ThreadGroup(threads=[described_asset], group_id=described_asset.id), + ] + + manager = ThreadEmbeddingBatchManager(batch, ctx) + + with patch.object( + manager.batch_factory, + "get_batch_groups", + return_value=groups, + ): + state = await manager._transition(CreatedState()) + assert isinstance(state, ThreadEmbedPendingState) + + # Verify only 2 threads were submitted (text + described asset) + items = ctx.llm_client.embed_batch_submit.call_args[0][1] + submitted_ids = [item.item_id for item in items] + assert "text-1" in submitted_ids + assert "asset-2" in submitted_ids + assert "asset-1" not in submitted_ids diff --git a/tests/unit/thread_embedding/test_states.py b/tests/unit/thread_embedding/test_states.py new file mode 100644 index 00000000..ba7088a7 --- /dev/null +++ b/tests/unit/thread_embedding/test_states.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from context_use.batch.states import ( + CompleteState, + CreatedState, + FailedState, + SkippedState, +) +from context_use.thread_embedding.states import ( + ThreadEmbedCompleteState, + ThreadEmbedPendingState, + parse_thread_embedding_batch_state, +) + + +class TestParseThreadEmbeddingBatchState: + def test_created(self) -> None: + state = parse_thread_embedding_batch_state({"status": "CREATED"}) + assert isinstance(state, CreatedState) + + def test_embed_pending(self) -> None: + state = parse_thread_embedding_batch_state( + {"status": "THREAD_EMBED_PENDING", "job_key": "job-1"} + ) + assert isinstance(state, ThreadEmbedPendingState) + assert state.job_key == "job-1" + + def test_embed_complete(self) -> None: + state = parse_thread_embedding_batch_state( + {"status": "THREAD_EMBED_COMPLETE", "embedded_count": 5} + ) + assert isinstance(state, ThreadEmbedCompleteState) + assert state.embedded_count == 5 + + def test_complete(self) -> None: + state = parse_thread_embedding_batch_state({"status": "COMPLETE"}) + assert isinstance(state, CompleteState) + + def test_skipped(self) -> None: + state = parse_thread_embedding_batch_state( + {"status": "SKIPPED", "reason": "empty"} + ) + assert isinstance(state, SkippedState) + + def test_failed(self) -> None: + state = parse_thread_embedding_batch_state( + {"status": "FAILED", "error_message": "boom", "previous_status": "CREATED"} + ) + assert isinstance(state, FailedState) + + def test_unknown_raises(self) -> None: + import pytest + + with pytest.raises(ValueError, match="Unknown thread_embedding"): + parse_thread_embedding_batch_state({"status": "BOGUS"}) + + +class TestThreadEmbedPendingState: + def test_poll_countdown_is_non_negative(self) -> None: + state = ThreadEmbedPendingState(job_key="job-1") + assert state.poll_next_countdown >= 0 + + def test_round_trip(self) -> None: + state = ThreadEmbedPendingState(job_key="job-1") + dumped = state.model_dump(mode="json") + restored = parse_thread_embedding_batch_state(dumped) + assert isinstance(restored, ThreadEmbedPendingState) + assert restored.job_key == "job-1" From f0d566c71c3e7e5a831c54c89e28287858eb3f6a Mon Sep 17 00:00:00 2001 From: Alex Kulikov Date: Thu, 16 Apr 2026 15:07:07 +0100 Subject: [PATCH 2/2] feat: add embedding search methods --- context_use/core.py | 17 ++- context_use/store/base.py | 26 +++++ context_use/store/sqlite/store.py | 53 +++++++++- tests/unit/store/test_sqlite.py | 115 +++++++++++++++++++++ tests/unit/thread_embedding/test_search.py | 69 +++++++++++++ 5 files changed, 278 insertions(+), 2 deletions(-) create mode 100644 tests/unit/thread_embedding/test_search.py diff --git a/context_use/core.py b/context_use/core.py index a614bb8f..50a45fe8 100644 --- a/context_use/core.py +++ b/context_use/core.py @@ -28,7 +28,7 @@ get_memory_config, get_memory_interaction_types, ) -from context_use.store.base import MemorySearchResult +from context_use.store.base import MemorySearchResult, ThreadSearchResult from context_use.types import PipelineResult, TaskBreakdown if TYPE_CHECKING: @@ -442,6 +442,21 @@ async def insert_threads( """Insert thread rows into the store, deduplicating on ``unique_key``.""" return await self._store.insert_threads(rows, task_id) + async def search_threads( + self, + query: str, + *, + top_k: int = 10, + interaction_types: list[str] | None = None, + ) -> list[ThreadSearchResult]: + """Search threads by semantic similarity.""" + query_embedding = await self._llm_client.embed_query(query) + return await self._store.search_threads( + query_embedding=query_embedding, + top_k=top_k, + interaction_types=interaction_types, + ) + # ── Private helpers ────────────────────────────────────────────── def _batch_context(self) -> BatchContext: diff --git a/context_use/store/base.py b/context_use/store/base.py index b8236610..b161692a 100644 --- a/context_use/store/base.py +++ b/context_use/store/base.py @@ -37,6 +37,17 @@ class MemorySearchResult: similarity: float | None +@dataclass(frozen=True) +class ThreadSearchResult: + """A thread search hit with similarity score.""" + + id: str + interaction_type: str + content: str + asat: datetime + similarity: float + + class Store(ABC): """Abstract store for all context_use domain entities. @@ -265,6 +276,21 @@ async def upsert_thread_embedding( """Insert or replace the embedding vector for a thread.""" ... + @abstractmethod + async def search_threads( + self, + *, + query_embedding: list[float], + top_k: int = 10, + interaction_types: list[str] | None = None, + ) -> list[ThreadSearchResult]: + """Search threads by semantic similarity. + + If *interaction_types* is given, only threads whose + ``interaction_type`` is in that list are returned. + """ + ... + # ── Memory Facets ──────────────────────────────────────────────── @abstractmethod diff --git a/context_use/store/sqlite/store.py b/context_use/store/sqlite/store.py index f8766f8b..773d9ece 100644 --- a/context_use/store/sqlite/store.py +++ b/context_use/store/sqlite/store.py @@ -22,7 +22,12 @@ Thread, ) from context_use.models.utils import generate_uuidv4 -from context_use.store.base import MemorySearchResult, SortOrder, Store +from context_use.store.base import ( + MemorySearchResult, + SortOrder, + Store, + ThreadSearchResult, +) from context_use.store.sqlite.schema import ( ArchiveRow, BatchRow, @@ -665,6 +670,52 @@ async def upsert_thread_embedding( ) await self._commit_unless_atomic() + async def search_threads( + self, + *, + query_embedding: list[float], + top_k: int = 10, + interaction_types: list[str] | None = None, + ) -> list[ThreadSearchResult]: + db = await self._conn() + vec_rows = await db.execute_fetchall( + "SELECT thread_id, distance FROM vec_threads " + "WHERE embedding MATCH ? AND k = ?", + (VecThreadRow.serialize(query_embedding), top_k * 4), + ) + if not vec_rows: + return [] + + candidate_ids = [r[0] for r in vec_rows] + distances: dict[str, float] = {r[0]: r[1] for r in vec_rows} + + ph = ",".join("?" for _ in candidate_ids) + sql = ( + "SELECT id, interaction_type, content, asat " + f"FROM threads WHERE id IN ({ph})" + ) + params: list = list(candidate_ids) + if interaction_types is not None: + type_ph = ",".join("?" for _ in interaction_types) + sql += f" AND interaction_type IN ({type_ph})" + params.extend(interaction_types) + + thread_rows = await db.execute_fetchall(sql, params) + + results = [ + ThreadSearchResult( + id=r["id"], + interaction_type=r["interaction_type"], + content=r["content"] or "", + asat=parse_dt(r["asat"]), + similarity=1.0 - distances[r["id"]], + ) + for r in thread_rows + if r["id"] in distances + ] + results.sort(key=lambda x: x.similarity, reverse=True) + return results[:top_k] + async def create_memory_facet(self, facet: MemoryFacet) -> MemoryFacet: db = await self._conn() await db.execute( diff --git a/tests/unit/store/test_sqlite.py b/tests/unit/store/test_sqlite.py index d4233454..6f4ce7b5 100644 --- a/tests/unit/store/test_sqlite.py +++ b/tests/unit/store/test_sqlite.py @@ -515,6 +515,121 @@ async def test_upsert_thread_embedding_replaces_existing(store: SqliteStore) -> assert len(result) == 1 +async def _insert_thread_with_embedding( + store: SqliteStore, + *, + unique_key: str, + interaction_type: str = "test_type", + content: str | None = None, + embedding: list[float], +) -> str: + """Helper: insert a thread and its embedding, return the thread ID.""" + archive = Archive(provider="test") + await store.create_archive(archive) + task = EtlTask( + archive_id=archive.id, + provider="test", + interaction_type=interaction_type, + source_uris=["test.json"], + ) + await store.create_task(task) + rows = [ + ThreadRow( + unique_key=unique_key, + provider="test", + interaction_type=interaction_type, + preview="p", + payload={ + "type": "Create", + "fibre_kind": "Create", + "object": {"type": "Note"}, + }, + version="1.0", + asat=datetime(2025, 1, 1, tzinfo=UTC), + ) + ] + ids = await store.insert_threads(rows, task.id) + thread_id = ids[0] + if content is not None: + await store.update_thread_content(thread_id, content) + await store.upsert_thread_embedding(thread_id, embedding) + return thread_id + + +async def test_search_threads_by_embedding(store: SqliteStore) -> None: + emb_similar = [1.0, 0.0, 0.0, 0.0] + emb_different = [0.0, 0.0, 0.0, 1.0] + + id_similar = await _insert_thread_with_embedding( + store, + unique_key="uk-similar", + content="similar content", + embedding=emb_similar, + ) + await _insert_thread_with_embedding( + store, + unique_key="uk-different", + content="different content", + embedding=emb_different, + ) + + query_emb = [0.9, 0.1, 0.0, 0.0] + results = await store.search_threads(query_embedding=query_emb, top_k=1) + + assert len(results) == 1 + assert results[0].id == id_similar + assert results[0].content == "similar content" + assert results[0].similarity > 0.0 + + +async def test_search_threads_filters_by_interaction_type(store: SqliteStore) -> None: + emb = [1.0, 0.0, 0.0, 0.0] + + await _insert_thread_with_embedding( + store, + unique_key="uk-type-a", + interaction_type="type_a", + content="content a", + embedding=emb, + ) + id_b = await _insert_thread_with_embedding( + store, + unique_key="uk-type-b", + interaction_type="type_b", + content="content b", + embedding=emb, + ) + + results = await store.search_threads( + query_embedding=emb, + top_k=10, + interaction_types=["type_b"], + ) + + assert len(results) == 1 + assert results[0].id == id_b + + +async def test_search_threads_returns_empty_when_no_embeddings( + store: SqliteStore, +) -> None: + results = await store.search_threads(query_embedding=[1.0, 0.0, 0.0, 0.0], top_k=5) + assert results == [] + + +async def test_search_threads_respects_top_k(store: SqliteStore) -> None: + for i in range(5): + await _insert_thread_with_embedding( + store, + unique_key=f"uk-topk-{i}", + content=f"content {i}", + embedding=[1.0, 0.0, 0.0, 0.0], + ) + + results = await store.search_threads(query_embedding=[1.0, 0.0, 0.0, 0.0], top_k=3) + assert len(results) == 3 + + async def test_atomic_commits_on_success(store: SqliteStore) -> None: async with store.atomic(): archive = Archive(provider="test") diff --git a/tests/unit/thread_embedding/test_search.py b/tests/unit/thread_embedding/test_search.py new file mode 100644 index 00000000..1b0d695f --- /dev/null +++ b/tests/unit/thread_embedding/test_search.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from context_use.core import ContextUse +from context_use.store.base import ThreadSearchResult + + +def _make_ctx(*, search_results: list[ThreadSearchResult]) -> ContextUse: + store = AsyncMock() + store.search_threads = AsyncMock(return_value=search_results) + + llm_client = MagicMock() + llm_client.embed_query = AsyncMock(return_value=[1.0, 0.0, 0.0]) + + ctx = object.__new__(ContextUse) + ctx._store = store + ctx._llm_client = llm_client + ctx._storage = MagicMock() + return ctx + + +class TestSearchThreads: + @pytest.mark.asyncio + async def test_embeds_query_and_delegates_to_store(self) -> None: + from datetime import UTC, datetime + + hit = ThreadSearchResult( + id="t1", + interaction_type="chatgpt_conversations", + content="hello world", + asat=datetime(2025, 1, 1, tzinfo=UTC), + similarity=0.95, + ) + ctx = _make_ctx(search_results=[hit]) + + results = await ctx.search_threads("hello") + + assert len(results) == 1 + assert results[0].id == "t1" + assert results[0].similarity == 0.95 + + ctx._llm_client.embed_query.assert_awaited_once_with("hello") # type: ignore[union-attr] + ctx._store.search_threads.assert_awaited_once() # type: ignore[union-attr] + call_kwargs = ctx._store.search_threads.call_args.kwargs # type: ignore[union-attr] + assert call_kwargs["query_embedding"] == [1.0, 0.0, 0.0] + assert call_kwargs["top_k"] == 10 + + @pytest.mark.asyncio + async def test_forwards_top_k_and_interaction_types(self) -> None: + ctx = _make_ctx(search_results=[]) + + await ctx.search_threads( + "query", + top_k=5, + interaction_types=["instagram_posts"], + ) + + call_kwargs = ctx._store.search_threads.call_args.kwargs # type: ignore[union-attr] + assert call_kwargs["top_k"] == 5 + assert call_kwargs["interaction_types"] == ["instagram_posts"] + + @pytest.mark.asyncio + async def test_returns_empty_when_no_results(self) -> None: + ctx = _make_ctx(search_results=[]) + results = await ctx.search_threads("nothing") + assert results == []