From a01a44b0dd622cdcaf60825eea78c3700761c953 Mon Sep 17 00:00:00 2001 From: svalench Date: Fri, 8 May 2026 07:09:25 +0000 Subject: [PATCH] feat(conversational): session-aware conversational search endpoint Adds Sprint 3: an optional conversational search layer on top of the existing Searcher / LangGraph pipeline. Highlights: * New CONVERSATIONAL settings section with safe defaults (disabled). * New memory/ subpackage: - BaseMemoryBackend contract + ConversationEvent dataclass. - InMemoryBackend (process-local, thread-safe, bounded deque). - DjangoCacheBackend (works with any Django cache, including Redis via django-redis, without taking a hard Redis dependency). - build_memory_backend factory with short aliases (inmemory, cache, redis) and dotted-path support. * New langgraph_conversation module with the conversational graph: load_context -> interpret_followup -> maybe_clarify -> [execute_search] -> store_context, plus an in-tree fallback runner. * Follow-up interpretation is conservative on purpose: 'more', 'similar', 'only X' patterns reuse the previous turn; otherwise the query is left untouched. Ambiguous short follow-ups produce a structured clarification_needed flag rather than a hallucinated query. * Top results are persisted to memory in a compact form (model + pk + score only) so backends remain serialisation-friendly. * New ConversationalSearchAPIView mounted at POST /api/search/conversation/ (and DELETE for clearing history). Returns 404 when CONVERSATIONAL.ENABLED is False so the URL is safe to leave registered globally. * 16 new tests cover memory backends, individual nodes, end-to-end graph behaviour and the HTTP view. Backwards compatibility: * All previous tests still pass unchanged (36 passed total). * Existing /api/search/ and /api/search/similar/ endpoints are untouched. 
--- README.md | 57 +++ .../langgraph_conversation.py | 269 ++++++++++++++ src/django_graph_search/memory/__init__.py | 20 ++ src/django_graph_search/memory/base.py | 65 ++++ .../memory/django_cache.py | 54 +++ src/django_graph_search/memory/factory.py | 47 +++ src/django_graph_search/memory/in_memory.py | 41 +++ src/django_graph_search/settings.py | 53 +++ src/django_graph_search/urls.py | 7 +- src/django_graph_search/views.py | 147 ++++++++ tests/test_conversational_search.py | 331 ++++++++++++++++++ 11 files changed, 1090 insertions(+), 1 deletion(-) create mode 100644 src/django_graph_search/langgraph_conversation.py create mode 100644 src/django_graph_search/memory/__init__.py create mode 100644 src/django_graph_search/memory/base.py create mode 100644 src/django_graph_search/memory/django_cache.py create mode 100644 src/django_graph_search/memory/factory.py create mode 100644 src/django_graph_search/memory/in_memory.py create mode 100644 tests/test_conversational_search.py diff --git a/README.md b/README.md index eb72b43..f6cd9fe 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,8 @@ similar = get_similar(product_instance, limit=5) |---|---|---| | `/api/search/?q=...&models=...&limit=...` | `GET` | Semantic full-text search | | `/api/search/similar/{app}.{Model}/{id}/` | `GET` | Find similar objects | +| `/api/search/conversation/` | `POST` | Session-aware conversational search (optional, see below) | +| `/api/search/conversation/?conversation_id=...` | `DELETE` | Clear a conversation history | ## Management Commands @@ -251,6 +253,61 @@ If `langgraph` is not installed, the pipeline transparently uses an in-tree sequential runner with the same node structure, so behaviour and tests stay identical. +## Conversational search (optional) + +For session-aware semantic search (follow-ups like "more", "only products", +"similar") enable the conversational endpoint. 
It is a thin search-first +shell on top of `Searcher` and never invents user intent: ambiguous +follow-ups are surfaced as a structured `clarification_needed` flag instead +of a hallucinated query. + +```python +GRAPH_SEARCH = { + # ... existing config ... + "CONVERSATIONAL": { + "ENABLED": True, + "MEMORY_BACKEND": "inmemory", # or "cache" / dotted path. + "MAX_HISTORY_ITEMS": 10, + "ALLOW_CLARIFICATIONS": True, + }, +} +``` + +Endpoint: `POST /api/search/conversation/` + +```json +// Request +{ + "query": "only products", + "conversation_id": "abc-123", + "models": ["shop.Product"], + "limit": 5 +} + +// Response +{ + "conversation_id": "abc-123", + "query": "only products", + "interpreted_query": "red phone", + "clarification_needed": false, + "results": [...], + "total": 5 +} +``` + +Use `DELETE /api/search/conversation/?conversation_id=abc-123` to clear a +conversation. + +Built-in memory backends: + +| Alias | Class | Best for | +|---|---|---| +| `inmemory` | `InMemoryBackend` | Tests, single-worker dev | +| `cache` / `redis` | `DjangoCacheBackend` | Production via Django cache (Redis, memcached) | + +Bring your own by subclassing `BaseMemoryBackend` and pointing +`MEMORY_BACKEND` at the dotted path. + ## Comparison | Feature | django-graph-search | Haystack | django-elasticsearch-dsl | diff --git a/src/django_graph_search/langgraph_conversation.py b/src/django_graph_search/langgraph_conversation.py new file mode 100644 index 0000000..7d5fcdb --- /dev/null +++ b/src/django_graph_search/langgraph_conversation.py @@ -0,0 +1,269 @@ +"""Conversational search graph. + +The graph is a thin shell on top of the existing :class:`Searcher`. It is +search-first, not chat-first: nodes are deterministic by default, never +hallucinate filters, and surface a structured ``clarification_needed`` flag +when the follow-up query is too ambiguous. 
+ +Pipeline: + +``` +load_context \u2192 interpret_followup \u2192 maybe_clarify \u2192 [execute_search] \u2192 store_context +``` + +* ``load_context`` reads the recent history from the memory backend. +* ``interpret_followup`` rewrites short follow-ups using the previous turn + (\"show me more\", \"only products\", \"what about cheaper ones\"). +* ``maybe_clarify`` decides whether the rewritten query is meaningful enough + to run a search, otherwise returns ``clarification_needed=True``. +* ``execute_search`` delegates to the existing searcher (which itself can be + the LangGraph search pipeline if enabled). +* ``store_context`` persists the user query and a compact view of results + back into the memory backend. +""" +from __future__ import annotations + +import logging +import re +import uuid +from typing import Any, Dict, Iterable, List, Optional, TypedDict + +from .memory.base import BaseMemoryBackend, ConversationEvent +from .settings import GraphSearchConfig + +log = logging.getLogger(__name__) + + +class ConversationState(TypedDict, total=False): + conversation_id: str + raw_query: str + interpreted_query: str + history: List[ConversationEvent] + models: Optional[List[str]] + limit: int + clarification_needed: bool + clarification_message: str + results: List[Dict[str, Any]] + debug: Dict[str, Any] + + +# Heuristics for short follow-ups. Patterns are intentionally conservative +# — the goal is to avoid hallucinating user intent.
+_FOLLOWUP_MORE = re.compile( + r"^(more|еще|ещё|ещё\s+пожалуйста|показать ещё|show me more|next)\b", + re.IGNORECASE, +) +_FOLLOWUP_SIMILAR = re.compile( + r"^(similar|похож|аналог|like that|something else)\b", re.IGNORECASE +) +_FOLLOWUP_FILTER = re.compile( + r"^(only|just|filter|только)\s+(?P<value>[\w\.\- ]+)$", re.IGNORECASE +) + + +def load_context_node( + state: ConversationState, *, memory: BaseMemoryBackend +) -> ConversationState: + cid = state.get("conversation_id") or str(uuid.uuid4()) + state["conversation_id"] = cid + state["history"] = memory.get_history(cid) + state.setdefault("debug", {})["history_size"] = len(state["history"]) + return state + + +def interpret_followup_node( + state: ConversationState, + *, + config: GraphSearchConfig, +) -> ConversationState: + """Resolve short follow-ups using the previous turn. + + Returns the original query verbatim when no follow-up pattern matches. + Never invents filters that have no backing in the history. + """ + raw = (state.get("raw_query") or "").strip() + history = state.get("history") or [] + last_user = _last_user_event(history) + interpreted = raw + + if raw and last_user is not None: + m = _FOLLOWUP_FILTER.match(raw) + if m: + # 'only X' - reuse previous query, narrow models if value matches. + base = last_user.interpreted_query or last_user.query + interpreted = base + state["models"] = _filter_models_to_value( + m.group("value").strip(), + last_user.models or [], + ) + elif _FOLLOWUP_MORE.match(raw): + interpreted = last_user.interpreted_query or last_user.query + elif _FOLLOWUP_SIMILAR.match(raw): + interpreted = last_user.interpreted_query or last_user.query + elif len(raw) <= config.conversational.min_query_length_for_autosearch: + # Very short query — lean on previous context.
+ base = last_user.interpreted_query or last_user.query + interpreted = f"{base} {raw}".strip() if base else raw + + if not state.get("models") and last_user is not None and last_user.models: + state.setdefault("models", last_user.models) + + state["interpreted_query"] = interpreted + state.setdefault("debug", {})["followup_resolved"] = interpreted != raw + return state + + +def maybe_clarify_node( + state: ConversationState, *, config: GraphSearchConfig +) -> ConversationState: + interpreted = (state.get("interpreted_query") or "").strip() + if not config.conversational.allow_clarifications: + state["clarification_needed"] = False + return state + too_short = len(interpreted) < config.conversational.min_query_length_for_autosearch + has_history = bool(state.get("history")) + if too_short and not has_history: + state["clarification_needed"] = True + state["clarification_message"] = ( + "Could you give a bit more context? The query is too short to search reliably." + ) + state.setdefault("results", []) + else: + state["clarification_needed"] = False + return state + + +def execute_search_node( + state: ConversationState, + *, + searcher, +) -> ConversationState: + if state.get("clarification_needed"): + state["results"] = [] + return state + interpreted = state.get("interpreted_query") or "" + if not interpreted.strip(): + state["results"] = [] + return state + try: + results = searcher.search( + interpreted, + models=state.get("models"), + limit=int(state.get("limit") or 0) or None, + ) + except Exception as exc: # noqa: BLE001 + log.exception("Conversational search failed: %s", exc) + state.setdefault("debug", {})["search_error"] = str(exc) + results = [] + state["results"] = results + return state + + +def store_context_node( + state: ConversationState, *, memory: BaseMemoryBackend +) -> ConversationState: + cid = state["conversation_id"] + user_event = ConversationEvent( + role="user", + query=state.get("raw_query", ""), + 
interpreted_query=state.get("interpreted_query", ""), + models=state.get("models"), + top_results=_compact_results(state.get("results") or []), + clarification_needed=bool(state.get("clarification_needed", False)), + ) + memory.append_event(cid, user_event) + return state + + +def build_conversation_graph( + config: GraphSearchConfig, *, searcher, memory: BaseMemoryBackend +): + """Compile the conversational graph (LangGraph or fallback runner).""" + try: + from langgraph.graph import END, StateGraph # type: ignore + except Exception: # pragma: no cover - exercised when langgraph absent. + return _FallbackConversationGraph( + config=config, searcher=searcher, memory=memory + ) + + graph: Any = StateGraph(dict) + graph.add_node("load_context", lambda s: load_context_node(s, memory=memory)) + graph.add_node("interpret_followup", lambda s: interpret_followup_node(s, config=config)) + graph.add_node("maybe_clarify", lambda s: maybe_clarify_node(s, config=config)) + graph.add_node("execute_search", lambda s: execute_search_node(s, searcher=searcher)) + graph.add_node("store_context", lambda s: store_context_node(s, memory=memory)) + + graph.set_entry_point("load_context") + graph.add_edge("load_context", "interpret_followup") + graph.add_edge("interpret_followup", "maybe_clarify") + graph.add_conditional_edges( + "maybe_clarify", + lambda s: "store_context" if s.get("clarification_needed") else "execute_search", + ) + graph.add_edge("execute_search", "store_context") + graph.add_edge("store_context", END) + return graph.compile() + + +class _FallbackConversationGraph: + def __init__( + self, + *, + config: GraphSearchConfig, + searcher, + memory: BaseMemoryBackend, + ) -> None: + self.config = config + self.searcher = searcher + self.memory = memory + + def invoke(self, state: ConversationState) -> ConversationState: + state = load_context_node(state, memory=self.memory) + state = interpret_followup_node(state, config=self.config) + state = maybe_clarify_node(state, 
config=self.config) + if not state.get("clarification_needed"): + state = execute_search_node(state, searcher=self.searcher) + state = store_context_node(state, memory=self.memory) + return state + + +def _last_user_event( + history: Iterable[ConversationEvent], +) -> Optional[ConversationEvent]: + last: Optional[ConversationEvent] = None + for event in history: + if event.role == "user": + last = event + return last + + +def _filter_models_to_value(value: str, candidates: List[str]) -> List[str]: + """Return candidates that mention ``value``; otherwise leave unchanged.""" + if not candidates: + return [] + needle = value.strip().lower() + matched = [c for c in candidates if needle in c.lower()] + return matched or candidates + + +def _compact_results(results: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Drop heavy fields before persisting results to memory.""" + compact: List[Dict[str, Any]] = [] + for item in list(results)[:5]: + compact.append({ + "model": item.get("model"), + "pk": item.get("pk"), + "score": item.get("score"), + }) + return compact + + +__all__ = [ + "ConversationState", + "load_context_node", + "interpret_followup_node", + "maybe_clarify_node", + "execute_search_node", + "store_context_node", + "build_conversation_graph", +] diff --git a/src/django_graph_search/memory/__init__.py b/src/django_graph_search/memory/__init__.py new file mode 100644 index 0000000..e4775a9 --- /dev/null +++ b/src/django_graph_search/memory/__init__.py @@ -0,0 +1,20 @@ +"""Pluggable conversation memory backends. + +Used by the optional conversational search endpoint to remember the last few +queries / interpreted queries / result references per ``conversation_id``. +The default is the in-process backend, which is enough for single-worker +deployments and tests. + +The contract is intentionally tiny so users can plug Redis, the Django cache +framework or even a database-backed table in a few lines. 
+""" +from .base import BaseMemoryBackend, ConversationEvent +from .factory import build_memory_backend +from .in_memory import InMemoryBackend + +__all__ = [ + "BaseMemoryBackend", + "ConversationEvent", + "InMemoryBackend", + "build_memory_backend", +] diff --git a/src/django_graph_search/memory/base.py b/src/django_graph_search/memory/base.py new file mode 100644 index 0000000..a8c1acb --- /dev/null +++ b/src/django_graph_search/memory/base.py @@ -0,0 +1,65 @@ +"""Conversation memory contract. + +Memory backends store a small, serialisable trail of recent search events per +conversation. We deliberately avoid storing ORM instances or long blobs so +events can be persisted to Redis / cache backends without surprises. +""" +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class ConversationEvent: + """A single turn of the conversation. + + Only metadata about results is kept (model + pk), never the full payload. 
+ """ + + role: str # "user" | "assistant" + query: str = "" + interpreted_query: str = "" + models: Optional[List[str]] = None + top_results: List[Dict[str, Any]] = field(default_factory=list) + clarification_needed: bool = False + timestamp: float = field(default_factory=time.time) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def from_dict(cls, payload: Dict[str, Any]) -> "ConversationEvent": + return cls( + role=payload.get("role", "user"), + query=payload.get("query", "") or "", + interpreted_query=payload.get("interpreted_query", "") or "", + models=payload.get("models"), + top_results=payload.get("top_results") or [], + clarification_needed=bool(payload.get("clarification_needed", False)), + timestamp=float(payload.get("timestamp") or time.time()), + ) + + +class BaseMemoryBackend(ABC): + """Minimal interface every memory backend implements.""" + + def __init__(self, max_history_items: int = 10, **options: Any) -> None: + if max_history_items < 1: + raise ValueError("max_history_items must be >= 1") + self.max_history_items = max_history_items + self.options = options + + @abstractmethod + def get_history(self, session_id: str) -> List[ConversationEvent]: + """Return events in chronological order (oldest first).""" + + @abstractmethod + def append_event(self, session_id: str, event: ConversationEvent) -> None: + """Append an event, trimming to ``max_history_items`` from the right.""" + + @abstractmethod + def clear_history(self, session_id: str) -> None: + """Drop all events for ``session_id``.""" diff --git a/src/django_graph_search/memory/django_cache.py b/src/django_graph_search/memory/django_cache.py new file mode 100644 index 0000000..eff8dbe --- /dev/null +++ b/src/django_graph_search/memory/django_cache.py @@ -0,0 +1,54 @@ +"""Memory backend that piggy-backs on the Django cache framework. 
+ +Works with any Django cache (locmem, Redis via django-redis, memcached, ...), +which means we get a Redis-capable backend without taking a hard dependency +on a Redis client. A wholly Redis-specific backend can be added later if +needed. +""" +from __future__ import annotations + +from typing import List + +from django.core.cache import caches + +from .base import BaseMemoryBackend, ConversationEvent + + +class DjangoCacheBackend(BaseMemoryBackend): + def __init__( + self, + max_history_items: int = 10, + alias: str = "default", + key_prefix: str = "dgs:conv:", + ttl: int = 86400, + **options, + ) -> None: + super().__init__(max_history_items=max_history_items, **options) + self.alias = alias + self.key_prefix = key_prefix + self.ttl = ttl + + @property + def _cache(self): + return caches[self.alias] + + def _key(self, session_id: str) -> str: + return f"{self.key_prefix}{session_id}" + + def get_history(self, session_id: str) -> List[ConversationEvent]: + payload = self._cache.get(self._key(session_id)) or [] + return [ConversationEvent.from_dict(p) for p in payload] + + def append_event(self, session_id: str, event: ConversationEvent) -> None: + history = self.get_history(session_id) + history.append(event) + if len(history) > self.max_history_items: + history = history[-self.max_history_items:] + self._cache.set( + self._key(session_id), + [e.to_dict() for e in history], + timeout=self.ttl, + ) + + def clear_history(self, session_id: str) -> None: + self._cache.delete(self._key(session_id)) diff --git a/src/django_graph_search/memory/factory.py b/src/django_graph_search/memory/factory.py new file mode 100644 index 0000000..455f3d6 --- /dev/null +++ b/src/django_graph_search/memory/factory.py @@ -0,0 +1,47 @@ +"""Resolve a memory backend from settings.""" +from __future__ import annotations + +from typing import Any, Dict, Optional + +from django.utils.module_loading import import_string + +from ..exceptions import ConfigurationError +from .base import 
BaseMemoryBackend +from .in_memory import InMemoryBackend + + +_ALIASES = { + "inmemory": "django_graph_search.memory.in_memory.InMemoryBackend", + "memory": "django_graph_search.memory.in_memory.InMemoryBackend", + "cache": "django_graph_search.memory.django_cache.DjangoCacheBackend", + "django_cache": "django_graph_search.memory.django_cache.DjangoCacheBackend", + "redis": "django_graph_search.memory.django_cache.DjangoCacheBackend", +} + + +def build_memory_backend( + backend: Optional[str], + *, + max_history_items: int = 10, + options: Optional[Dict[str, Any]] = None, +) -> BaseMemoryBackend: + """Instantiate a memory backend from a short alias or dotted path. + + Defaults to :class:`InMemoryBackend` when ``backend`` is ``None``. + """ + options = dict(options or {}) + if not backend: + return InMemoryBackend(max_history_items=max_history_items, **options) + path = _ALIASES.get(backend, backend) + try: + cls = import_string(path) + except ImportError as exc: + raise ConfigurationError( + f"Cannot import memory backend '{backend}': {exc}" + ) from exc + instance = cls(max_history_items=max_history_items, **options) + if not isinstance(instance, BaseMemoryBackend): + raise ConfigurationError( + f"Memory backend '{backend}' must subclass BaseMemoryBackend." + ) + return instance diff --git a/src/django_graph_search/memory/in_memory.py b/src/django_graph_search/memory/in_memory.py new file mode 100644 index 0000000..001bbb1 --- /dev/null +++ b/src/django_graph_search/memory/in_memory.py @@ -0,0 +1,41 @@ +"""Process-local memory backend. + +Suitable for tests, single-worker deployments and as a drop-in default. For +multi-worker setups, use the Redis or Django cache backend instead. +""" +from __future__ import annotations + +import threading +from collections import deque +from typing import Deque, Dict, List + +from .base import BaseMemoryBackend, ConversationEvent + + +class InMemoryBackend(BaseMemoryBackend): + """Bounded per-session deque of events. 
+ + Thread-safe via a single lock; the critical sections are tiny so + contention is not a real concern at conversational-search workloads. + """ + + def __init__(self, max_history_items: int = 10, **options) -> None: + super().__init__(max_history_items=max_history_items, **options) + self._store: Dict[str, Deque[ConversationEvent]] = {} + self._lock = threading.Lock() + + def get_history(self, session_id: str) -> List[ConversationEvent]: + with self._lock: + return list(self._store.get(session_id, ())) + + def append_event(self, session_id: str, event: ConversationEvent) -> None: + with self._lock: + bucket = self._store.get(session_id) + if bucket is None: + bucket = deque(maxlen=self.max_history_items) + self._store[session_id] = bucket + bucket.append(event) + + def clear_history(self, session_id: str) -> None: + with self._lock: + self._store.pop(session_id, None) diff --git a/src/django_graph_search/settings.py b/src/django_graph_search/settings.py index c596b64..22ed070 100644 --- a/src/django_graph_search/settings.py +++ b/src/django_graph_search/settings.py @@ -53,6 +53,15 @@ "OPTIONS": {}, }, }, + "CONVERSATIONAL": { + "ENABLED": False, + "MEMORY_BACKEND": "inmemory", + "MEMORY_OPTIONS": {}, + "MAX_HISTORY_ITEMS": 10, + "ALLOW_CLARIFICATIONS": True, + "MIN_QUERY_LENGTH_FOR_AUTOSEARCH": 2, + "FOLLOWUP_GRAPH": "django_graph_search.langgraph_conversation.build_conversation_graph", + }, } @@ -108,6 +117,17 @@ class LangGraphConfig: llm: LLMConfig = field(default_factory=LLMConfig) +@dataclass(frozen=True) +class ConversationalConfig: + enabled: bool = False + memory_backend: str = "inmemory" + memory_options: Dict[str, Any] = field(default_factory=dict) + max_history_items: int = 10 + allow_clarifications: bool = True + min_query_length_for_autosearch: int = 2 + followup_graph: str = "django_graph_search.langgraph_conversation.build_conversation_graph" + + @dataclass(frozen=True) class GraphSearchConfig: models: List[ModelConfig] @@ -121,6 +141,7 @@ class 
GraphSearchConfig: delta_indexing: bool cache: CacheConfig langgraph: LangGraphConfig = field(default_factory=LangGraphConfig) + conversational: ConversationalConfig = field(default_factory=ConversationalConfig) def _merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: @@ -212,6 +233,7 @@ def get_settings() -> GraphSearchConfig: ) langgraph_cfg = _build_langgraph_config(merged.get("LANGGRAPH") or {}) + conversational_cfg = _build_conversational_config(merged.get("CONVERSATIONAL") or {}) # Validate backend paths early _load_backend(vector_store.backend) @@ -230,6 +252,7 @@ def get_settings() -> GraphSearchConfig: delta_indexing=bool(merged.get("DELTA_INDEXING", False)), cache=cache_cfg, langgraph=langgraph_cfg, + conversational=conversational_cfg, ) @@ -273,3 +296,33 @@ def _build_langgraph_config(payload: Dict[str, Any]) -> LangGraphConfig: llm=llm_cfg, ) + +def _build_conversational_config(payload: Dict[str, Any]) -> ConversationalConfig: + if not isinstance(payload, dict): + raise ConfigurationError("CONVERSATIONAL must be a dict.") + defaults = DEFAULTS["CONVERSATIONAL"] + merged = _merge_dicts(defaults, payload) + max_history = int(merged.get("MAX_HISTORY_ITEMS", 10)) + if max_history < 1: + raise ConfigurationError("CONVERSATIONAL.MAX_HISTORY_ITEMS must be >= 1.") + min_qlen = int(merged.get("MIN_QUERY_LENGTH_FOR_AUTOSEARCH", 2)) + if min_qlen < 0: + raise ConfigurationError( + "CONVERSATIONAL.MIN_QUERY_LENGTH_FOR_AUTOSEARCH must be >= 0." 
+ ) + options = merged.get("MEMORY_OPTIONS") or {} + if not isinstance(options, dict): + raise ConfigurationError("CONVERSATIONAL.MEMORY_OPTIONS must be a dict.") + return ConversationalConfig( + enabled=bool(merged.get("ENABLED", False)), + memory_backend=str(merged.get("MEMORY_BACKEND") or "inmemory"), + memory_options=options, + max_history_items=max_history, + allow_clarifications=bool(merged.get("ALLOW_CLARIFICATIONS", True)), + min_query_length_for_autosearch=min_qlen, + followup_graph=str( + merged.get("FOLLOWUP_GRAPH") + or "django_graph_search.langgraph_conversation.build_conversation_graph" + ), + ) + diff --git a/src/django_graph_search/urls.py b/src/django_graph_search/urls.py index 0307842..193ed9b 100644 --- a/src/django_graph_search/urls.py +++ b/src/django_graph_search/urls.py @@ -1,10 +1,15 @@ from django.urls import path -from .views import SearchAPIView, SimilarAPIView +from .views import ConversationalSearchAPIView, SearchAPIView, SimilarAPIView urlpatterns = [ path("", SearchAPIView.as_view(), name="graph_search"), path("similar///", SimilarAPIView.as_view(), name="graph_search_similar"), + path( + "conversation/", + ConversationalSearchAPIView.as_view(), + name="graph_search_conversation", + ), ] diff --git a/src/django_graph_search/views.py b/src/django_graph_search/views.py index cc1773e..8696e93 100644 --- a/src/django_graph_search/views.py +++ b/src/django_graph_search/views.py @@ -1,12 +1,21 @@ from __future__ import annotations +import json +import logging +import uuid +from typing import Any, Dict, List, Optional + from django.apps import apps from django.http import JsonResponse +from django.utils.decorators import method_decorator from django.views import View +from django.views.decorators.csrf import csrf_exempt from .searcher import Searcher from .settings import get_settings +log = logging.getLogger(__name__) + class SearchAPIView(View): def get(self, request, *args, **kwargs): @@ -26,6 +35,144 @@ def get(self, request, *args, 
**kwargs): ) +@method_decorator(csrf_exempt, name="dispatch") +class ConversationalSearchAPIView(View): + """Session-aware semantic search. + + Accepts ``POST`` with a JSON body or a form payload: + + .. code-block:: json + + { + "query": "only products", + "conversation_id": "abc-123", + "models": ["shop.Product"], + "limit": 5 + } + + Returns: + + .. code-block:: json + + { + "conversation_id": "abc-123", + "query": "only products", + "interpreted_query": "red phone", + "clarification_needed": false, + "results": [...], + "total": 5 + } + + The endpoint disables itself with HTTP 404 when + ``CONVERSATIONAL.ENABLED`` is false, so it is safe to leave the URL + registered globally. + """ + + # Module-level memory cache: keep a single backend per process so the + # in-memory variant actually persists across requests in tests and dev. + _memory_cache: Dict[str, Any] = {} + + def post(self, request, *args, **kwargs): + cfg = get_settings() + if not cfg.conversational.enabled: + return JsonResponse({"error": "Conversational search is disabled."}, status=404) + payload = self._parse_body(request) + query = (payload.get("query") or "").strip() + if not query: + return JsonResponse({"error": "Parameter 'query' is required."}, status=400) + conversation_id = payload.get("conversation_id") or str(uuid.uuid4()) + models = payload.get("models") + if isinstance(models, str): + models = [m.strip() for m in models.split(",") if m.strip()] + limit = payload.get("limit") + try: + limit_value = int(limit) if limit is not None else None + except (TypeError, ValueError): + limit_value = None + + memory = self._get_memory_backend(cfg) + searcher = Searcher() + graph = self._build_graph(cfg, searcher=searcher, memory=memory) + state = { + "conversation_id": conversation_id, + "raw_query": query, + "models": list(models) if models else None, + "limit": limit_value or cfg.default_results_limit, + } + try: + out = graph.invoke(state) + except Exception as exc: # noqa: BLE001 + 
log.exception("Conversational graph failed: %s", exc) + return JsonResponse({"error": "Internal error in conversational graph."}, status=500) + + body = { + "conversation_id": out.get("conversation_id", conversation_id), + "query": query, + "interpreted_query": out.get("interpreted_query", query), + "clarification_needed": bool(out.get("clarification_needed", False)), + "clarification_message": out.get("clarification_message", ""), + "results": out.get("results") or [], + "total": len(out.get("results") or []), + } + return JsonResponse(body, status=200) + + # DELETE clears the stored history for a conversation. + def delete(self, request, *args, **kwargs): + cfg = get_settings() + if not cfg.conversational.enabled: + return JsonResponse({"error": "Conversational search is disabled."}, status=404) + cid = request.GET.get("conversation_id") or self._parse_body(request).get("conversation_id") + if not cid: + return JsonResponse({"error": "conversation_id is required."}, status=400) + memory = self._get_memory_backend(cfg) + memory.clear_history(cid) + return JsonResponse({"conversation_id": cid, "cleared": True}, status=200) + + # ----------------------------------------------------------- helpers + + @staticmethod + def _parse_body(request) -> Dict[str, Any]: + if request.body: + content_type = (request.META.get("CONTENT_TYPE") or "").split(";")[0].strip() + if content_type == "application/json": + try: + return json.loads(request.body.decode("utf-8")) or {} + except (ValueError, UnicodeDecodeError): + return {} + # Fall back to form data / query params.
+ merged: Dict[str, Any] = {} + merged.update(request.POST.dict() if hasattr(request.POST, "dict") else {}) + merged.update(request.GET.dict() if hasattr(request.GET, "dict") else {}) + return merged + + @classmethod + def _get_memory_backend(cls, cfg): + from .memory import build_memory_backend + + cache_key = ( + cfg.conversational.memory_backend, + tuple(sorted((cfg.conversational.memory_options or {}).items())), + cfg.conversational.max_history_items, + ) + existing = cls._memory_cache.get(cache_key) + if existing is not None: + return existing + backend = build_memory_backend( + cfg.conversational.memory_backend, + max_history_items=cfg.conversational.max_history_items, + options=cfg.conversational.memory_options, + ) + cls._memory_cache[cache_key] = backend + return backend + + @staticmethod + def _build_graph(cfg, *, searcher, memory): + from django.utils.module_loading import import_string + + factory = import_string(cfg.conversational.followup_graph) + return factory(cfg, searcher=searcher, memory=memory) + + class SimilarAPIView(View): def get(self, request, model: str, pk: str, *args, **kwargs): if "." 
not in model: diff --git a/tests/test_conversational_search.py b/tests/test_conversational_search.py new file mode 100644 index 0000000..6c0927a --- /dev/null +++ b/tests/test_conversational_search.py @@ -0,0 +1,331 @@ +"""Tests for the optional conversational search endpoint and graph.""" +from __future__ import annotations + +import json +from typing import Any, Dict, List +from unittest import mock + +import pytest +from django.conf import settings as django_settings +from django.test import RequestFactory + +from django_graph_search.langgraph_conversation import ( + ConversationState, + build_conversation_graph, + interpret_followup_node, + maybe_clarify_node, +) +from django_graph_search.memory import ( + BaseMemoryBackend, + ConversationEvent, + InMemoryBackend, + build_memory_backend, +) +from django_graph_search.settings import get_settings +from django_graph_search.views import ConversationalSearchAPIView + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def graph_search_settings(): + original = getattr(django_settings, "GRAPH_SEARCH", None) + get_settings.cache_clear() + # Reset the per-process memory cache between tests so each test starts + # from a clean slate. 
+ ConversationalSearchAPIView._memory_cache.clear() + + def _apply(payload): + django_settings.GRAPH_SEARCH = payload + get_settings.cache_clear() + return get_settings() + + yield _apply + + if original is None and hasattr(django_settings, "GRAPH_SEARCH"): + delattr(django_settings, "GRAPH_SEARCH") + elif original is not None: + django_settings.GRAPH_SEARCH = original + get_settings.cache_clear() + ConversationalSearchAPIView._memory_cache.clear() + + +def _base_settings(extra=None): + payload = { + "MODELS": [], + "VECTOR_STORE": {"BACKEND": "django_graph_search.backends.ChromaDBBackend"}, + "EMBEDDINGS": { + "default": { + "BACKEND": "tests.dummy_embedding_backend.DummyEmbeddingBackend", + "MODEL_NAME": "x", + } + }, + "CONVERSATIONAL": {"ENABLED": True}, + } + if extra: + payload.update(extra) + return payload + + +# --------------------------------------------------------------------------- +# Memory backends +# --------------------------------------------------------------------------- + + +def test_in_memory_backend_appends_and_truncates(): + backend = InMemoryBackend(max_history_items=2) + for i in range(3): + backend.append_event( + "s1", ConversationEvent(role="user", query=f"q{i}", interpreted_query=f"q{i}") + ) + history = backend.get_history("s1") + assert [e.query for e in history] == ["q1", "q2"] + + +def test_in_memory_backend_clear_history(): + backend = InMemoryBackend(max_history_items=5) + backend.append_event("s1", ConversationEvent(role="user", query="q")) + backend.clear_history("s1") + assert backend.get_history("s1") == [] + + +def test_factory_returns_in_memory_for_alias(): + backend = build_memory_backend("inmemory", max_history_items=3) + assert isinstance(backend, InMemoryBackend) + assert backend.max_history_items == 3 + + +def test_conversation_event_round_trip(): + ev = ConversationEvent(role="user", query="x", models=["m"], top_results=[{"pk": 1}]) + other = ConversationEvent.from_dict(ev.to_dict()) + assert other.query == "x" + 
assert other.models == ["m"] + assert other.top_results == [{"pk": 1}] + + +# --------------------------------------------------------------------------- +# Nodes +# --------------------------------------------------------------------------- + + +class _StubSearcher: + def __init__(self, results): + self._results = results + self.calls: List[Dict[str, Any]] = [] + + def search(self, query, models=None, limit=None): + self.calls.append({"query": query, "models": models, "limit": limit}) + return list(self._results) + + +def test_interpret_followup_uses_previous_query_for_short_input(graph_search_settings): + cfg = graph_search_settings(_base_settings()) + history = [ + ConversationEvent(role="user", query="red phone", interpreted_query="red phone"), + ] + state: ConversationState = { + "raw_query": "more", + "history": history, + } + out = interpret_followup_node(state, config=cfg) + assert out["interpreted_query"] == "red phone" + + +def test_interpret_followup_only_filter_narrows_models(graph_search_settings): + cfg = graph_search_settings(_base_settings()) + history = [ + ConversationEvent( + role="user", + query="red phone", + interpreted_query="red phone", + models=["shop.Product", "blog.Post"], + ), + ] + state: ConversationState = { + "raw_query": "only Product", + "history": history, + } + out = interpret_followup_node(state, config=cfg) + assert out["interpreted_query"] == "red phone" + assert out["models"] == ["shop.Product"] + + +def test_maybe_clarify_triggers_for_too_short_query_without_history(graph_search_settings): + cfg = graph_search_settings(_base_settings({ + "CONVERSATIONAL": { + "ENABLED": True, + "MIN_QUERY_LENGTH_FOR_AUTOSEARCH": 3, + }, + })) + state: ConversationState = {"interpreted_query": "x", "history": []} + out = maybe_clarify_node(state, config=cfg) + assert out["clarification_needed"] is True + assert out.get("clarification_message") + + +def test_maybe_clarify_skipped_when_disabled(graph_search_settings): + cfg = 
graph_search_settings(_base_settings({ + "CONVERSATIONAL": { + "ENABLED": True, + "ALLOW_CLARIFICATIONS": False, + }, + })) + state: ConversationState = {"interpreted_query": "x", "history": []} + out = maybe_clarify_node(state, config=cfg) + assert out["clarification_needed"] is False + + +# --------------------------------------------------------------------------- +# Graph (fallback runner) end-to-end +# --------------------------------------------------------------------------- + + +def test_graph_executes_search_and_persists_history(graph_search_settings): + cfg = graph_search_settings(_base_settings()) + memory = InMemoryBackend(max_history_items=5) + searcher = _StubSearcher([{"model": "shop.Product", "pk": 1, "score": 0.9}]) + graph = build_conversation_graph(cfg, searcher=searcher, memory=memory) + + out = graph.invoke({"raw_query": "red phone", "limit": 5}) + cid = out["conversation_id"] + assert out["results"][0]["pk"] == 1 + assert searcher.calls and searcher.calls[0]["query"] == "red phone" + history = memory.get_history(cid) + assert len(history) == 1 and history[0].query == "red phone" + + +def test_graph_followup_uses_previous_context(graph_search_settings): + cfg = graph_search_settings(_base_settings()) + memory = InMemoryBackend(max_history_items=5) + searcher = _StubSearcher([{"model": "shop.Product", "pk": 1, "score": 0.9}]) + graph = build_conversation_graph(cfg, searcher=searcher, memory=memory) + + first = graph.invoke({"raw_query": "red phone", "limit": 5}) + cid = first["conversation_id"] + second = graph.invoke({"raw_query": "more", "conversation_id": cid, "limit": 5}) + # interpret_followup should rewrite "more" using the previous turn. 
+ assert second["interpreted_query"] == "red phone" + assert searcher.calls[-1]["query"] == "red phone" + + +def test_graph_returns_clarification_for_ambiguous_short_input(graph_search_settings): + cfg = graph_search_settings(_base_settings({ + "CONVERSATIONAL": { + "ENABLED": True, + "MIN_QUERY_LENGTH_FOR_AUTOSEARCH": 3, + }, + })) + memory = InMemoryBackend(max_history_items=5) + searcher = _StubSearcher([]) + graph = build_conversation_graph(cfg, searcher=searcher, memory=memory) + out = graph.invoke({"raw_query": "x", "limit": 5}) + assert out["clarification_needed"] is True + assert out["results"] == [] + # Search should not have been invoked. + assert searcher.calls == [] + + +# --------------------------------------------------------------------------- +# View +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db +def test_view_returns_404_when_disabled(graph_search_settings): + graph_search_settings({ + "MODELS": [], + "VECTOR_STORE": {"BACKEND": "django_graph_search.backends.ChromaDBBackend"}, + "EMBEDDINGS": { + "default": { + "BACKEND": "tests.dummy_embedding_backend.DummyEmbeddingBackend", + "MODEL_NAME": "x", + } + }, + }) + factory = RequestFactory() + request = factory.post( + "/api/search/conversation/", + data=json.dumps({"query": "anything"}), + content_type="application/json", + ) + response = ConversationalSearchAPIView.as_view()(request) + assert response.status_code == 404 + + +@pytest.mark.django_db +def test_view_runs_full_pipeline(graph_search_settings): + graph_search_settings(_base_settings()) + factory = RequestFactory() + request = factory.post( + "/api/search/conversation/", + data=json.dumps({"query": "red phone"}), + content_type="application/json", + ) + with mock.patch("django_graph_search.views.Searcher") as searcher_cls: + searcher_cls.return_value.search.return_value = [ + {"model": "shop.Product", "pk": 1, "score": 0.5} + ] + response = 
ConversationalSearchAPIView.as_view()(request) + body = json.loads(response.content.decode()) + assert response.status_code == 200 + assert body["query"] == "red phone" + assert body["interpreted_query"] == "red phone" + assert body["clarification_needed"] is False + assert body["results"][0]["pk"] == 1 + assert body["conversation_id"] + + +@pytest.mark.django_db +def test_view_followup_uses_history_across_two_calls(graph_search_settings): + graph_search_settings(_base_settings()) + factory = RequestFactory() + + def post(payload): + request = factory.post( + "/api/search/conversation/", + data=json.dumps(payload), + content_type="application/json", + ) + with mock.patch("django_graph_search.views.Searcher") as searcher_cls: + searcher_cls.return_value.search.return_value = [ + {"model": "shop.Product", "pk": 1, "score": 0.5} + ] + return ConversationalSearchAPIView.as_view()(request), searcher_cls + + response, _ = post({"query": "red phone"}) + cid = json.loads(response.content.decode())["conversation_id"] + + response2, searcher_cls2 = post({"query": "more", "conversation_id": cid}) + body2 = json.loads(response2.content.decode()) + assert body2["interpreted_query"] == "red phone" + # The mocked searcher must have been queried with "red phone". + args, kwargs = searcher_cls2.return_value.search.call_args + actual_query = args[0] if args else kwargs.get("query") + assert actual_query == "red phone" + + +@pytest.mark.django_db +def test_view_clear_history(graph_search_settings): + graph_search_settings(_base_settings()) + factory = RequestFactory() + # Seed memory by issuing a real call. 
+ request = factory.post( + "/api/search/conversation/", + data=json.dumps({"query": "hello"}), + content_type="application/json", + ) + with mock.patch("django_graph_search.views.Searcher") as searcher_cls: + searcher_cls.return_value.search.return_value = [] + response = ConversationalSearchAPIView.as_view()(request) + cid = json.loads(response.content.decode())["conversation_id"] + + # Now clear it. + delete_request = factory.delete( + f"/api/search/conversation/?conversation_id={cid}", + ) + delete_response = ConversationalSearchAPIView.as_view()(delete_request) + assert delete_response.status_code == 200 + assert json.loads(delete_response.content.decode())["cleared"] is True