Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion context_use/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
get_memory_config,
get_memory_interaction_types,
)
from context_use.store.base import MemorySearchResult
from context_use.store.base import MemorySearchResult, ThreadSearchResult
from context_use.types import PipelineResult, TaskBreakdown

if TYPE_CHECKING:
Expand Down Expand Up @@ -442,6 +442,21 @@ async def insert_threads(
"""Insert thread rows into the store, deduplicating on ``unique_key``."""
return await self._store.insert_threads(rows, task_id)

async def search_threads(
    self,
    query: str,
    *,
    top_k: int = 10,
    interaction_types: list[str] | None = None,
) -> list[ThreadSearchResult]:
    """Find threads that are semantically close to *query*.

    The query text is embedded with the configured LLM client, and the
    resulting vector is passed to the store's ``search_threads``.

    Args:
        query: Free-text query to embed.
        top_k: Result limit, forwarded to the store unchanged.
        interaction_types: Optional interaction-type filter, forwarded
            to the store unchanged (``None`` stays ``None``).

    Returns:
        The hits produced by the store for the embedded query.
    """
    embedding = await self._llm_client.embed_query(query)
    hits = await self._store.search_threads(
        query_embedding=embedding,
        top_k=top_k,
        interaction_types=interaction_types,
    )
    return hits

# ── Private helpers ──────────────────────────────────────────────

def _batch_context(self) -> BatchContext:
Expand Down
26 changes: 26 additions & 0 deletions context_use/store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ class MemorySearchResult:
similarity: float | None


@dataclass(frozen=True)
class ThreadSearchResult:
    """A thread search hit with similarity score.

    Instances are immutable (``frozen=True``). They are produced by
    ``Store.search_threads`` implementations.
    """

    # Identifier of the matched thread.
    id: str
    # Interaction type recorded on the thread; search can filter on it.
    interaction_type: str
    # Thread content; the SQLite store substitutes "" when none is stored.
    content: str
    # Value of the thread's ``asat`` timestamp.
    asat: datetime
    # Similarity score; the SQLite store derives it as ``1.0 - distance``.
    similarity: float


class Store(ABC):
"""Abstract store for all context_use domain entities.

Expand Down Expand Up @@ -265,6 +276,21 @@ async def upsert_thread_embedding(
"""Insert or replace the embedding vector for a thread."""
...

@abstractmethod
async def search_threads(
    self,
    *,
    query_embedding: list[float],
    top_k: int = 10,
    interaction_types: list[str] | None = None,
) -> list[ThreadSearchResult]:
    """Search threads by semantic similarity.

    If *interaction_types* is given, only threads whose
    ``interaction_type`` is in that list are returned.

    Args:
        query_embedding: Embedding vector of the query to compare
            thread embeddings against.
        top_k: Maximum number of hits an implementation should return.
        interaction_types: Optional allow-list of interaction types;
            ``None`` means no filtering by type.

    Returns:
        The matching threads as ``ThreadSearchResult`` objects.
    """
    ...

# ── Memory Facets ────────────────────────────────────────────────

@abstractmethod
Expand Down
53 changes: 52 additions & 1 deletion context_use/store/sqlite/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
Thread,
)
from context_use.models.utils import generate_uuidv4
from context_use.store.base import MemorySearchResult, SortOrder, Store
from context_use.store.base import (
MemorySearchResult,
SortOrder,
Store,
ThreadSearchResult,
)
from context_use.store.sqlite.schema import (
ArchiveRow,
BatchRow,
Expand Down Expand Up @@ -665,6 +670,52 @@ async def upsert_thread_embedding(
)
await self._commit_unless_atomic()

async def search_threads(
    self,
    *,
    query_embedding: list[float],
    top_k: int = 10,
    interaction_types: list[str] | None = None,
) -> list[ThreadSearchResult]:
    """Return up to *top_k* threads ranked by vector similarity.

    A KNN query against ``vec_threads`` fetches ``top_k * 4`` candidate
    ids. The matching ``threads`` rows are then loaded, optionally
    restricted to *interaction_types*, and ranked by ``1.0 - distance``.

    The type filter is applied after the KNN lookup, so fewer than
    *top_k* hits may be returned even when more matching threads exist
    outside the candidate window.

    Args:
        query_embedding: Query vector to match against stored thread
            embeddings.
        top_k: Maximum number of hits to return. Values ``<= 0`` yield
            an empty list without touching the database.
        interaction_types: Optional allow-list of interaction types.
            ``None`` disables the filter; an empty list can match
            nothing and yields an empty list immediately.

    Returns:
        Hits sorted by descending similarity, at most *top_k* long.
    """
    # Degenerate requests cannot produce hits: skip the vector scan, and
    # avoid a negative ``top_k`` turning the final slice into "all but
    # the last N".
    if top_k <= 0:
        return []
    if interaction_types is not None and not interaction_types:
        return []

    # Over-fetch so that the interaction-type filter below still has a
    # chance of leaving ``top_k`` rows.
    overfetch_factor = 4

    db = await self._conn()
    vec_rows = await db.execute_fetchall(
        "SELECT thread_id, distance FROM vec_threads "
        "WHERE embedding MATCH ? AND k = ?",
        (VecThreadRow.serialize(query_embedding), top_k * overfetch_factor),
    )
    if not vec_rows:
        return []

    candidate_ids = [row[0] for row in vec_rows]
    distances: dict[str, float] = {row[0]: row[1] for row in vec_rows}

    # Only "?" placeholders are interpolated into the SQL text; every
    # value is bound as a parameter.
    id_placeholders = ",".join("?" for _ in candidate_ids)
    sql = (
        "SELECT id, interaction_type, content, asat "
        f"FROM threads WHERE id IN ({id_placeholders})"
    )
    params: list[str] = list(candidate_ids)
    if interaction_types is not None:
        type_placeholders = ",".join("?" for _ in interaction_types)
        sql += f" AND interaction_type IN ({type_placeholders})"
        params.extend(interaction_types)

    thread_rows = await db.execute_fetchall(sql, params)

    results = [
        ThreadSearchResult(
            id=row["id"],
            interaction_type=row["interaction_type"],
            content=row["content"] or "",
            asat=parse_dt(row["asat"]),
            similarity=1.0 - distances[row["id"]],
        )
        for row in thread_rows
        # Guard the lookup so an id that does not map back to a KNN
        # candidate cannot raise KeyError.
        if row["id"] in distances
    ]
    results.sort(key=lambda hit: hit.similarity, reverse=True)
    return results[:top_k]

async def create_memory_facet(self, facet: MemoryFacet) -> MemoryFacet:
db = await self._conn()
await db.execute(
Expand Down
115 changes: 115 additions & 0 deletions tests/unit/store/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,121 @@ async def test_upsert_thread_embedding_replaces_existing(store: SqliteStore) ->
assert len(result) == 1


async def _insert_thread_with_embedding(
    store: SqliteStore,
    *,
    unique_key: str,
    interaction_type: str = "test_type",
    content: str | None = None,
    embedding: list[float],
) -> str:
    """Create one thread with an embedding and return the thread's ID.

    A fresh archive and ETL task are created for every call. The thread
    content is only written when *content* is not ``None``.
    """
    parent_archive = Archive(provider="test")
    await store.create_archive(parent_archive)

    etl_task = EtlTask(
        archive_id=parent_archive.id,
        provider="test",
        interaction_type=interaction_type,
        source_uris=["test.json"],
    )
    await store.create_task(etl_task)

    thread_row = ThreadRow(
        unique_key=unique_key,
        provider="test",
        interaction_type=interaction_type,
        preview="p",
        payload={
            "type": "Create",
            "fibre_kind": "Create",
            "object": {"type": "Note"},
        },
        version="1.0",
        asat=datetime(2025, 1, 1, tzinfo=UTC),
    )
    inserted_ids = await store.insert_threads([thread_row], etl_task.id)
    new_thread_id = inserted_ids[0]

    if content is not None:
        await store.update_thread_content(new_thread_id, content)
    await store.upsert_thread_embedding(new_thread_id, embedding)
    return new_thread_id


async def test_search_threads_by_embedding(store: SqliteStore) -> None:
    """The thread whose embedding is closest to the query is returned."""
    near_id = await _insert_thread_with_embedding(
        store,
        unique_key="uk-similar",
        content="similar content",
        embedding=[1.0, 0.0, 0.0, 0.0],
    )
    await _insert_thread_with_embedding(
        store,
        unique_key="uk-different",
        content="different content",
        embedding=[0.0, 0.0, 0.0, 1.0],
    )

    hits = await store.search_threads(
        query_embedding=[0.9, 0.1, 0.0, 0.0],
        top_k=1,
    )

    assert len(hits) == 1
    best = hits[0]
    assert best.id == near_id
    assert best.content == "similar content"
    assert best.similarity > 0.0


async def test_search_threads_filters_by_interaction_type(store: SqliteStore) -> None:
    """Only threads of the requested interaction type are returned."""
    shared_embedding = [1.0, 0.0, 0.0, 0.0]

    await _insert_thread_with_embedding(
        store,
        unique_key="uk-type-a",
        interaction_type="type_a",
        content="content a",
        embedding=shared_embedding,
    )
    wanted_id = await _insert_thread_with_embedding(
        store,
        unique_key="uk-type-b",
        interaction_type="type_b",
        content="content b",
        embedding=shared_embedding,
    )

    hits = await store.search_threads(
        query_embedding=shared_embedding,
        top_k=10,
        interaction_types=["type_b"],
    )

    assert len(hits) == 1
    assert hits[0].id == wanted_id


async def test_search_threads_returns_empty_when_no_embeddings(
    store: SqliteStore,
) -> None:
    """Searching a store without any embeddings yields an empty list."""
    probe = [1.0, 0.0, 0.0, 0.0]
    hits = await store.search_threads(query_embedding=probe, top_k=5)
    assert hits == []


async def test_search_threads_respects_top_k(store: SqliteStore) -> None:
    """No more than ``top_k`` hits come back when more threads match."""
    for index in range(5):
        await _insert_thread_with_embedding(
            store,
            unique_key=f"uk-topk-{index}",
            content=f"content {index}",
            embedding=[1.0, 0.0, 0.0, 0.0],
        )

    hits = await store.search_threads(
        query_embedding=[1.0, 0.0, 0.0, 0.0],
        top_k=3,
    )
    assert len(hits) == 3


async def test_atomic_commits_on_success(store: SqliteStore) -> None:
async with store.atomic():
archive = Archive(provider="test")
Expand Down
69 changes: 69 additions & 0 deletions tests/unit/thread_embedding/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

from context_use.core import ContextUse
from context_use.store.base import ThreadSearchResult


def _make_ctx(*, search_results: list[ThreadSearchResult]) -> ContextUse:
    """Build a ``ContextUse`` with stubbed collaborators, bypassing ``__init__``.

    The store's ``search_threads`` resolves to *search_results* and the
    LLM client's ``embed_query`` resolves to a fixed three-element vector.
    """
    fake_store = AsyncMock()
    fake_store.search_threads = AsyncMock(return_value=search_results)

    fake_llm = MagicMock()
    fake_llm.embed_query = AsyncMock(return_value=[1.0, 0.0, 0.0])

    # object.__new__ skips ContextUse.__init__, so the private attributes
    # are wired up by hand.
    instance = object.__new__(ContextUse)
    instance._store = fake_store
    instance._llm_client = fake_llm
    instance._storage = MagicMock()
    return instance


class TestSearchThreads:
    """Unit tests for ``ContextUse.search_threads`` using stubbed collaborators."""

    @pytest.mark.asyncio
    async def test_embeds_query_and_delegates_to_store(self) -> None:
        """The query is embedded once and the vector is passed to the store."""
        from datetime import UTC, datetime

        expected_hit = ThreadSearchResult(
            id="t1",
            interaction_type="chatgpt_conversations",
            content="hello world",
            asat=datetime(2025, 1, 1, tzinfo=UTC),
            similarity=0.95,
        )
        ctx = _make_ctx(search_results=[expected_hit])

        hits = await ctx.search_threads("hello")

        assert len(hits) == 1
        first = hits[0]
        assert first.id == "t1"
        assert first.similarity == 0.95

        ctx._llm_client.embed_query.assert_awaited_once_with("hello")  # type: ignore[union-attr]
        store_search = ctx._store.search_threads  # type: ignore[union-attr]
        store_search.assert_awaited_once()
        forwarded = store_search.call_args.kwargs
        assert forwarded["query_embedding"] == [1.0, 0.0, 0.0]
        assert forwarded["top_k"] == 10

    @pytest.mark.asyncio
    async def test_forwards_top_k_and_interaction_types(self) -> None:
        """Caller-supplied ``top_k`` and ``interaction_types`` reach the store."""
        ctx = _make_ctx(search_results=[])

        await ctx.search_threads(
            "query",
            top_k=5,
            interaction_types=["instagram_posts"],
        )

        forwarded = ctx._store.search_threads.call_args.kwargs  # type: ignore[union-attr]
        assert forwarded["top_k"] == 5
        assert forwarded["interaction_types"] == ["instagram_posts"]

    @pytest.mark.asyncio
    async def test_returns_empty_when_no_results(self) -> None:
        """An empty store result is returned to the caller as an empty list."""
        ctx = _make_ctx(search_results=[])
        hits = await ctx.search_threads("nothing")
        assert hits == []
Loading