From 2efc98348c2276d2da94f7c650c9f0f43bcc1e50 Mon Sep 17 00:00:00 2001 From: svalench Date: Fri, 8 May 2026 07:04:06 +0000 Subject: [PATCH] feat(langgraph): optional LangGraph search pipeline with query expansion and reranking Adds an opt-in orchestration layer (Sprint 1+2 combined) on top of the existing components without changing the public API. Highlights: * New LANGGRAPH settings section with safe defaults (disabled). * New langgraph_agent module with TypedDict state and 5 nodes: analyze_query, expand_query, vector_search, rerank, postprocess. * In-tree fallback runner so the package works without the langgraph package installed; uses langgraph.StateGraph when available. * New llm/ subpackage with BaseLLMBackend contract, DummyLLMBackend (deterministic, dependency-free) and a factory. * Multi-query merge with per-doc score consolidation and dedup. * Searcher.search and find_similar are unchanged in signature; when the graph fails the searcher falls back to the legacy linear path (FALLBACK_ON_ERROR=True by default). * 13 new tests (settings validation, individual nodes, end-to-end searcher behaviour with and without LangGraph). * README section documenting how to enable and configure the pipeline. Backwards compatibility: * All 8 pre-existing tests still pass unchanged. * When LANGGRAPH.ENABLED is False (default), Searcher.search executes the original linear path byte-for-byte. 
--- README.md | 55 ++++ setup.cfg | 4 + src/django_graph_search/langgraph_agent.py | 353 +++++++++++++++++++++ src/django_graph_search/llm/__init__.py | 24 ++ src/django_graph_search/llm/base.py | 69 ++++ src/django_graph_search/llm/dummy.py | 71 +++++ src/django_graph_search/llm/factory.py | 37 +++ src/django_graph_search/searcher.py | 106 ++++++- src/django_graph_search/settings.py | 86 ++++- tests/test_langgraph_search.py | 345 ++++++++++++++++++++ 10 files changed, 1140 insertions(+), 10 deletions(-) create mode 100644 src/django_graph_search/langgraph_agent.py create mode 100644 src/django_graph_search/llm/__init__.py create mode 100644 src/django_graph_search/llm/base.py create mode 100644 src/django_graph_search/llm/dummy.py create mode 100644 src/django_graph_search/llm/factory.py create mode 100644 tests/test_langgraph_search.py diff --git a/README.md b/README.md index 70577b8..eb72b43 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,61 @@ Enable `DELTA_INDEXING: True` to skip objects that haven’t changed since last | `redis` | `OPTIONS.alias` | Production | | `db` | `OPTIONS.alias` | Simple setup | +## LangGraph-powered search pipeline (optional) + +Starting with this version, `django-graph-search` ships with an **optional** +orchestration layer built on top of [LangGraph](https://langchain-ai.github.io/langgraph/). +It is disabled by default; the public API (`Searcher.search`, +`Searcher.find_similar`, REST endpoints) is fully backwards-compatible. + +When enabled, the pipeline runs as a small graph: + +``` +analyze_query → [expand_query] → vector_search → [rerank] → postprocess +``` + +Steps in `[brackets]` are toggled via settings, and each one degrades +gracefully: if the LLM backend fails or is not configured, the pipeline keeps +working using the deterministic vector search. + +```python +GRAPH_SEARCH = { + # ... your existing config ... + "LANGGRAPH": { + "ENABLED": True, # Master switch. 
+ "QUERY_EXPANSION": True, # Generate semantic reformulations. + "RERANKING": True, # Rerank top-K candidates. + "MAX_EXPANDED_QUERIES": 3, + "RERANK_TOP_K": 20, + "TIMEOUT_SECONDS": 15, + "MAX_QUERY_LENGTH": 1024, + "FALLBACK_ON_ERROR": True, # Fall back to legacy search on graph errors. + "USE_FOR_SIMILAR": False, # Route find_similar through the graph. + "LLM": { + # Leave BACKEND=None to use the deterministic dummy backend. + "BACKEND": None, + "MODEL": None, + "OPTIONS": {}, + }, + }, +} +``` + +### Bring your own LLM backend + +Implement `django_graph_search.llm.BaseLLMBackend` and point +`LANGGRAPH.LLM.BACKEND` at the dotted path. The contract is intentionally +tiny — `expand_query(query, models, max_variants)` and +`rerank(query, candidates, top_k)` — so you can wrap any provider +(OpenAI, Ollama, vLLM, your in-house service) in a few lines. + +### Why optional? + +The library refuses to add hard dependencies on `langgraph` or any LLM SDK. +If `langgraph` is not installed, the pipeline transparently uses an in-tree +sequential runner with the same node structure, so behaviour and tests stay +identical. + ## Comparison | Feature | django-graph-search | Haystack | django-elasticsearch-dsl | diff --git a/setup.cfg b/setup.cfg index 3ac5bc4..ceeb866 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,8 +40,12 @@ qdrant = qdrant-client>=1.6.0 test = pytest>=9.0.0 + pytest-django>=4.0 +langgraph = + langgraph>=0.2.0 all = chromadb>=0.5.0 faiss-cpu>=1.7.4 qdrant-client>=1.6.0 + langgraph>=0.2.0 diff --git a/src/django_graph_search/langgraph_agent.py b/src/django_graph_search/langgraph_agent.py new file mode 100644 index 0000000..aa272ad --- /dev/null +++ b/src/django_graph_search/langgraph_agent.py @@ -0,0 +1,353 @@ +"""LangGraph-powered orchestration layer for the search pipeline. + +This module is the optional successor to the linear flow currently +implemented in :meth:`~django_graph_search.searcher.Searcher.search`.
It +wraps the existing components (embedding backend, vector store, LLM backend) +in a small graph so that query analysis, expansion, vector lookup and +reranking can be composed and individually toggled via settings. + +Design goals: + +* **No hard dependency on the langgraph package.** The pipeline degrades to + a tiny in-tree runner when LangGraph is not installed. When LangGraph is + installed we use its ``StateGraph`` so users get streaming, checkpointing + and tracing for free. +* **Backwards-compatible defaults.** With ``LANGGRAPH.ENABLED = False`` + nothing in this module runs and ``Searcher.search`` behaves exactly as + before. +* **Stateless nodes.** Every node receives and returns a plain ``dict`` + state, which makes the pipeline trivial to test and serialize. +""" +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TypedDict + +from .llm.base import BaseLLMBackend, RerankCandidate +from .settings import GraphSearchConfig + +log = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# State definition +# --------------------------------------------------------------------------- + + +class SearchState(TypedDict, total=False): + """Mutable bag of values that flows through the search graph.""" + + query: str + normalized_query: str + expanded_queries: List[str] + models: Optional[List[str]] + limit: int + rerank_top_k: int + + # Results bookkeeping. + raw_results: List[Any] # List of vector store ResultItem objects. + merged_results: List[Any] + reranked_results: List[Any] + final_results: List[Any] + + # Diagnostics. 
+ errors: List[str] + debug: Dict[str, Any] + + +# --------------------------------------------------------------------------- +# Node implementations +# --------------------------------------------------------------------------- + + +def analyze_query_node(state: SearchState, *, config: GraphSearchConfig) -> SearchState: + """Normalize the query and seed default values. + + Kept deterministic on purpose: this node never calls an LLM, so it stays + fast and predictable. It enforces the configured ``MAX_QUERY_LENGTH``. + """ + query = (state.get("query") or "").strip() + max_len = config.langgraph.max_query_length + if max_len and len(query) > max_len: + query = query[:max_len] + state["normalized_query"] = query + state.setdefault("expanded_queries", [query] if query else []) + state.setdefault("debug", {})["normalized_length"] = len(query) + return state + + +def expand_query_node( + state: SearchState, + *, + config: GraphSearchConfig, + llm: BaseLLMBackend, +) -> SearchState: + """Generate semantic reformulations of the query via the LLM backend. + + On any failure we fall back to ``[normalized_query]`` so the rest of the + pipeline keeps working — that is the whole point of having a graph. + """ + base = state.get("normalized_query") or state.get("query", "") + if not base: + state["expanded_queries"] = [] + return state + max_variants = config.langgraph.max_expanded_queries + try: + variants = llm.expand_query(base, models=state.get("models"), max_variants=max_variants) + except Exception as exc: # noqa: BLE001 - LLM errors must never poison search. + log.warning("Query expansion failed, falling back to original query: %s", exc) + state.setdefault("errors", []).append(f"expand_query: {exc}") + variants = [base] + + # Make sure the original query is always present and dedup while preserving order. 
+ seen = set() + ordered: List[str] = [] + for item in [base, *variants]: + if not item or item in seen: + continue + seen.add(item) + ordered.append(item) + if len(ordered) >= max_variants: + break + state["expanded_queries"] = ordered + state.setdefault("debug", {})["expanded_count"] = len(ordered) + return state + + +def vector_search_node( + state: SearchState, + *, + embedding_backend, + vector_store, +) -> SearchState: + """Run the vector store query for every expanded query and merge hits. + + Results are deduplicated by ``(model, pk)``; we keep the highest score per + document because the underlying stores can return slightly different + scores for related queries. + """ + queries = state.get("expanded_queries") or [state.get("normalized_query") or ""] + queries = [q for q in queries if q] + limit = int(state.get("limit") or 0) or 20 + + # Multi-query merge keyed by document id. + merged: Dict[str, Any] = {} + for q in queries: + try: + vec = embedding_backend.embed(q) + hits = vector_store.search(vec, limit=limit, filters=None) + except Exception as exc: # noqa: BLE001 + log.warning("Vector search failed for query=%r: %s", q, exc) + state.setdefault("errors", []).append(f"vector_search: {exc}") + continue + for hit in hits: + key = _doc_key(hit) + existing = merged.get(key) + if existing is None or _score_value(hit) > _score_value(existing): + merged[key] = hit + + results = list(merged.values()) + + models_filter = state.get("models") + if models_filter: + allowed = set(models_filter) + results = [item for item in results if item.metadata.get("model") in allowed] + + # Stable order: best score first. 
+ results.sort(key=_score_value, reverse=True) + + state["raw_results"] = results + state["merged_results"] = results + state.setdefault("debug", {})["candidate_count"] = len(results) + return state + + +def rerank_results_node( + state: SearchState, + *, + config: GraphSearchConfig, + llm: BaseLLMBackend, +) -> SearchState: + """Optionally rerank the top-K candidates via the LLM backend.""" + candidates = state.get("merged_results") or [] + if not candidates: + state["reranked_results"] = [] + return state + top_k = int(state.get("rerank_top_k") or config.langgraph.rerank_top_k) + head = candidates[:top_k] + tail = candidates[top_k:] + + rerank_inputs = [ + RerankCandidate( + id=_doc_key(item), + text=getattr(item, "text", "") or "", + score=_score_value(item), + metadata=dict(item.metadata or {}), + ) + for item in head + ] + try: + reranked = llm.rerank( + state.get("normalized_query") or "", + rerank_inputs, + top_k=top_k, + ) + except Exception as exc: # noqa: BLE001 + log.warning("Reranking failed, keeping vector order: %s", exc) + state.setdefault("errors", []).append(f"rerank: {exc}") + state["reranked_results"] = candidates + return state + + by_id = {_doc_key(item): item for item in head} + ordered_head: List[Any] = [] + for cand in reranked: + item = by_id.pop(cand.id, None) + if item is not None: + ordered_head.append(item) + # Append any items the reranker dropped, preserving original order. 
+ for item in head: + key = _doc_key(item) + if key in by_id: + ordered_head.append(item) + by_id.pop(key, None) + + state["reranked_results"] = ordered_head + tail + return state + + +def postprocess_results_node(state: SearchState) -> SearchState: + """Apply ``limit`` and finalize the result list.""" + results = state.get("reranked_results") or state.get("merged_results") or state.get( + "raw_results" + ) or [] + limit = int(state.get("limit") or 0) + if limit and limit > 0: + results = results[:limit] + state["final_results"] = list(results) + return state + + +# --------------------------------------------------------------------------- +# Graph construction +# --------------------------------------------------------------------------- + + +def build_search_graph(config: GraphSearchConfig, *, embedding_backend, vector_store, llm: BaseLLMBackend): + """Build and compile the search graph. + + When the ``langgraph`` package is available we return a compiled + LangGraph ``StateGraph``. Otherwise we return :class:`_FallbackGraph` so + the rest of the code stays identical. + """ + try: + from langgraph.graph import END, StateGraph # type: ignore + except Exception: # pragma: no cover - exercised when langgraph absent. 
+ return _FallbackGraph(config=config, embedding_backend=embedding_backend, + vector_store=vector_store, llm=llm) + + graph: Any = StateGraph(dict) + graph.add_node("analyze_query", lambda s: analyze_query_node(s, config=config)) + graph.add_node( + "expand_query", + lambda s: expand_query_node(s, config=config, llm=llm), + ) + graph.add_node( + "vector_search", + lambda s: vector_search_node(s, embedding_backend=embedding_backend, vector_store=vector_store), + ) + graph.add_node( + "rerank_results", + lambda s: rerank_results_node(s, config=config, llm=llm), + ) + graph.add_node("postprocess_results", lambda s: postprocess_results_node(s)) + + graph.set_entry_point("analyze_query") + graph.add_conditional_edges( + "analyze_query", + lambda _s: "expand_query" if config.langgraph.query_expansion else "vector_search", + ) + graph.add_edge("expand_query", "vector_search") + graph.add_conditional_edges( + "vector_search", + lambda _s: "rerank_results" if config.langgraph.reranking else "postprocess_results", + ) + graph.add_edge("rerank_results", "postprocess_results") + graph.add_edge("postprocess_results", END) + return graph.compile() + + +# --------------------------------------------------------------------------- +# Fallback runner +# --------------------------------------------------------------------------- + + +class _FallbackGraph: + """Tiny sequential runner used when the langgraph package is missing. + + It mirrors the conditional structure of :func:`build_search_graph` so + behaviour stays identical regardless of LangGraph availability. 
+ """ + + def __init__( + self, + *, + config: GraphSearchConfig, + embedding_backend, + vector_store, + llm: BaseLLMBackend, + ) -> None: + self.config = config + self.embedding_backend = embedding_backend + self.vector_store = vector_store + self.llm = llm + + def invoke(self, state: SearchState) -> SearchState: + state = analyze_query_node(state, config=self.config) + if self.config.langgraph.query_expansion: + state = expand_query_node(state, config=self.config, llm=self.llm) + state = vector_search_node( + state, + embedding_backend=self.embedding_backend, + vector_store=self.vector_store, + ) + if self.config.langgraph.reranking: + state = rerank_results_node(state, config=self.config, llm=self.llm) + state = postprocess_results_node(state) + return state + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _doc_key(item: Any) -> str: + md = getattr(item, "metadata", {}) or {} + return f"{md.get('model')}::{md.get('pk')}" + + +def _score_value(item: Any) -> float: + score = getattr(item, "score", None) + try: + return float(score) if score is not None else 0.0 + except (TypeError, ValueError): + return 0.0 + + +def resolve_graph_factory(dotted_path: str) -> Callable[..., Any]: + """Lazily import a graph factory (used by the searcher).""" + from django.utils.module_loading import import_string + + return import_string(dotted_path) + + +__all__ = [ + "SearchState", + "analyze_query_node", + "expand_query_node", + "vector_search_node", + "rerank_results_node", + "postprocess_results_node", + "build_search_graph", + "resolve_graph_factory", +] diff --git a/src/django_graph_search/llm/__init__.py b/src/django_graph_search/llm/__init__.py new file mode 100644 index 0000000..8f85f8d --- /dev/null +++ b/src/django_graph_search/llm/__init__.py @@ -0,0 +1,24 @@ +"""Pluggable LLM backends used by the optional LangGraph search pipeline. 
+ +This subpackage is intentionally self-contained: it has no hard dependency on +LangGraph itself, so the rest of the library keeps working when LangGraph or a +remote LLM SDK is not installed. + +Public surface: + +* ``BaseLLMBackend`` — the contract every LLM backend implements. +* ``DummyLLMBackend`` — deterministic, dependency-free backend used in tests + and as the default fallback when no LLM is configured. +* ``build_llm_backend`` — factory that resolves the configured backend from + :class:`~django_graph_search.settings.LLMConfig`. +""" +from .base import BaseLLMBackend, RerankCandidate +from .dummy import DummyLLMBackend +from .factory import build_llm_backend + +__all__ = [ + "BaseLLMBackend", + "RerankCandidate", + "DummyLLMBackend", + "build_llm_backend", +] diff --git a/src/django_graph_search/llm/base.py b/src/django_graph_search/llm/base.py new file mode 100644 index 0000000..a5945a8 --- /dev/null +++ b/src/django_graph_search/llm/base.py @@ -0,0 +1,69 @@ +"""Base contract for pluggable LLM backends used by the LangGraph pipeline.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Sequence + + +@dataclass(frozen=True) +class RerankCandidate: + """Lightweight view of a search hit passed to a reranker. + + The reranker only needs a stable identifier, the textual content and the + original score. We deliberately avoid leaking ORM objects into the LLM + layer, which keeps backends serializable and easier to reason about. + """ + + id: str + text: str + score: float + metadata: Dict[str, Any] + + +class BaseLLMBackend(ABC): + """Minimal interface every LLM backend implements. + + Backend methods should be cheap to call; expensive setup (HTTP sessions, + model warm-up, ...) belongs in the constructor so we pay the cost once and + reuse the instance from the orchestrator.
+ """ + + def __init__(self, model: Optional[str] = None, **options: Any) -> None: + self.model = model + self.options = options + + @abstractmethod + def expand_query( + self, + query: str, + models: Optional[Iterable[str]] = None, + max_variants: int = 3, + ) -> List[str]: + """Return up to ``max_variants`` semantic reformulations of ``query``. + + Implementations MUST return at least the original query in the result + so callers can blindly iterate the returned list and never lose the + user's intent. + """ + + @abstractmethod + def rerank( + self, + query: str, + candidates: Sequence[RerankCandidate], + top_k: Optional[int] = None, + ) -> List[RerankCandidate]: + """Reorder ``candidates`` by relevance to ``query``. + + Implementations MUST be tolerant of empty inputs and never raise on + non-fatal errors — return the input order on failure instead. + """ + + # Optional capability hints. Subclasses can override them so the + # orchestrator can decide whether it is worth calling a node at all. + def supports_expansion(self) -> bool: + return True + + def supports_rerank(self) -> bool: + return True diff --git a/src/django_graph_search/llm/dummy.py b/src/django_graph_search/llm/dummy.py new file mode 100644 index 0000000..a36504d --- /dev/null +++ b/src/django_graph_search/llm/dummy.py @@ -0,0 +1,71 @@ +"""Deterministic, dependency-free LLM backend used as a safe default. + +Real LLM backends bring heavy dependencies and network calls. The dummy +backend lets the rest of the pipeline be exercised in unit tests, in CI and +in early-stage projects without any external service. It is deliberately +boring: query expansion produces simple normalisation variants (lowercased, +punctuation-stripped), and reranking is a stable sort by score. +""" +from __future__ import annotations + +import re +from typing import Iterable, List, Optional, Sequence + +from .base import BaseLLMBackend, RerankCandidate + + +class DummyLLMBackend(BaseLLMBackend): + """Tiny, predictable LLM stub.
+ + The backend never raises and never makes a network call. It is suitable + as a fallback when ``LANGGRAPH.LLM.BACKEND`` is unset and as a + deterministic baseline in tests. + """ + + _WORD_RE = re.compile(r"\w+", re.UNICODE) + + def expand_query( + self, + query: str, + models: Optional[Iterable[str]] = None, + max_variants: int = 3, + ) -> List[str]: + cleaned = (query or "").strip() + if not cleaned: + return [""] + variants: List[str] = [cleaned] + lowered = cleaned.lower() + if lowered != cleaned: + variants.append(lowered) + # Strip punctuation as a tiny normalisation variant. + words = self._WORD_RE.findall(cleaned) + if words: + joined = " ".join(words) + if joined and joined not in variants: + variants.append(joined) + # Deduplicate while preserving order. + seen = set() + ordered: List[str] = [] + for item in variants: + if item in seen: + continue + seen.add(item) + ordered.append(item) + if len(ordered) >= max(1, max_variants): + break + return ordered + + def rerank( + self, + query: str, + candidates: Sequence[RerankCandidate], + top_k: Optional[int] = None, + ) -> List[RerankCandidate]: + if not candidates: + return [] + # Stable sort: highest score first. ``score`` is similarity-like; for + # distance-based stores it is up to the vector store to flip the sign. + ordered = sorted(candidates, key=lambda c: c.score, reverse=True) + if top_k is not None and top_k > 0: + ordered = ordered[:top_k] + return ordered diff --git a/src/django_graph_search/llm/factory.py b/src/django_graph_search/llm/factory.py new file mode 100644 index 0000000..aa3830d --- /dev/null +++ b/src/django_graph_search/llm/factory.py @@ -0,0 +1,37 @@ +"""Resolve the configured LLM backend from settings. + +The factory keeps the rest of the codebase unaware of concrete implementations +and lets users plug their own backend via a dotted import path. 
+""" +from __future__ import annotations + +from typing import Optional + +from django.utils.module_loading import import_string + +from ..exceptions import ConfigurationError +from ..settings import LLMConfig +from .base import BaseLLMBackend +from .dummy import DummyLLMBackend + + +def build_llm_backend(config: Optional[LLMConfig]) -> BaseLLMBackend: + """Instantiate an LLM backend from ``LLMConfig``. + + When no backend is configured, returns the deterministic + :class:`DummyLLMBackend` so callers can rely on a non-None object. + """ + if config is None or not config.backend: + return DummyLLMBackend() + try: + backend_cls = import_string(config.backend) + except ImportError as exc: # pragma: no cover - guarded by tests + raise ConfigurationError( + f"Cannot import LLM backend '{config.backend}': {exc}" + ) from exc + instance = backend_cls(model=config.model, **(config.options or {})) + if not isinstance(instance, BaseLLMBackend): + raise ConfigurationError( + f"LLM backend '{config.backend}' must subclass BaseLLMBackend." + ) + return instance diff --git a/src/django_graph_search/searcher.py b/src/django_graph_search/searcher.py index 8f01dfa..ffe780b 100644 --- a/src/django_graph_search/searcher.py +++ b/src/django_graph_search/searcher.py @@ -2,16 +2,31 @@ # pylint: disable=duplicate-code +import logging from typing import Iterable, List, Optional from django.apps import apps from django.urls import reverse + from .components import ComponentMixin from .graph_resolver import GraphResolver +from .llm import BaseLLMBackend, build_llm_backend from .settings import GraphSearchConfig, ModelConfig, get_settings +log = logging.getLogger(__name__) + class Searcher(ComponentMixin): + """High-level search facade. + + The public API (:meth:`search`, :meth:`find_similar`) is unchanged. 
When + ``GRAPH_SEARCH["LANGGRAPH"]["ENABLED"]`` is true the call is routed + through the LangGraph orchestrator defined in + :mod:`django_graph_search.langgraph_agent`; otherwise it follows the + original linear path. Either way the returned shape is identical so + callers do not need to know which path executed. + """ + def __init__( self, config: Optional[GraphSearchConfig] = None, @@ -19,6 +34,7 @@ def __init__( embedding_backend=None, resolver: Optional[GraphResolver] = None, embedding_profile: Optional[str] = None, + llm_backend: Optional[BaseLLMBackend] = None, ) -> None: self._init_components( config=config, @@ -27,6 +43,10 @@ def __init__( resolver=resolver, embedding_profile=embedding_profile, ) + self._llm_backend = llm_backend + self._compiled_graph = None # Lazy. + + # ------------------------------------------------------------------ public def search( self, @@ -35,13 +55,15 @@ def search( limit: Optional[int] = None, ) -> List[dict]: limit = limit or self.config.default_results_limit - query_vector = self.embedding_backend.embed(query) - filters = None - results = self.vector_store.search(query_vector, limit=limit, filters=filters) - if models: - allowed = set(models) - results = [item for item in results if item.metadata.get("model") in allowed] - return [self._format_result(item) for item in results] + model_list = list(models) if models else None + if self.config.langgraph.enabled: + try: + return self._search_via_graph(query, models=model_list, limit=limit) + except Exception as exc: # noqa: BLE001 + if not self.config.langgraph.fallback_on_error: + raise + log.warning("LangGraph search failed, falling back to linear path: %s", exc) + return self._search_linear(query, models=model_list, limit=limit) def find_similar( self, @@ -51,6 +73,20 @@ def find_similar( limit = limit or self.config.default_results_limit model_cfg = self._find_model_config(instance._meta.label) text = self.resolver.build_searchable_text(instance, model_cfg) + # Reuse the same 
graph if requested; otherwise stay on the linear path + # because instance-level similarity has historically been simpler. + if self.config.langgraph.enabled and self.config.langgraph.use_for_similar: + try: + return self._search_via_graph( + text, + models=[instance._meta.label], + limit=limit, + ) + except Exception as exc: # noqa: BLE001 + if not self.config.langgraph.fallback_on_error: + raise + log.warning("LangGraph find_similar failed, falling back: %s", exc) + query_vector = self.embedding_backend.embed(text) results = self.vector_store.search( query_vector, @@ -59,6 +95,60 @@ def find_similar( ) return [self._format_result(item) for item in results] + # ----------------------------------------------------------- legacy path + + def _search_linear( + self, + query: str, + *, + models: Optional[List[str]], + limit: int, + ) -> List[dict]: + """Original deterministic search path. Kept for backwards compatibility.""" + query_vector = self.embedding_backend.embed(query) + results = self.vector_store.search(query_vector, limit=limit, filters=None) + if models: + allowed = set(models) + results = [item for item in results if item.metadata.get("model") in allowed] + return [self._format_result(item) for item in results] + + # ---------------------------------------------------------- LangGraph path + + def _search_via_graph( + self, + query: str, + *, + models: Optional[List[str]], + limit: int, + ) -> List[dict]: + graph = self._get_or_build_graph() + state = { + "query": query, + "models": models, + "limit": limit, + "rerank_top_k": self.config.langgraph.rerank_top_k, + } + out = graph.invoke(state) + results = out.get("final_results") or [] + return [self._format_result(item) for item in results] + + def _get_or_build_graph(self): + if self._compiled_graph is not None: + return self._compiled_graph + from .langgraph_agent import resolve_graph_factory + + factory = resolve_graph_factory(self.config.langgraph.search_graph) + llm = self._llm_backend or 
build_llm_backend(self.config.langgraph.llm) + self._compiled_graph = factory( + self.config, + embedding_backend=self.embedding_backend, + vector_store=self.vector_store, + llm=llm, + ) + return self._compiled_graph + + # --------------------------------------------------------------- helpers + def _format_result(self, item) -> dict: model_label = item.metadata.get("model") pk = item.metadata.get("pk") @@ -92,7 +182,6 @@ def _find_model_config(self, model_label: str) -> ModelConfig: for cfg in self.config.models: if cfg.model == model_label: return cfg - # Fallback: minimal config return ModelConfig(model=model_label, fields=[], follow_relations=True) def _get_model_class(self, model_label: str): @@ -100,4 +189,3 @@ def _get_model_class(self, model_label: str): return apps.get_model(model_label) app_label, model_name = model_label.split(".", 1) return apps.get_model(app_label, model_name) - diff --git a/src/django_graph_search/settings.py b/src/django_graph_search/settings.py index da7decf..c596b64 100644 --- a/src/django_graph_search/settings.py +++ b/src/django_graph_search/settings.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, Dict, Iterable, List +from typing import Any, Dict, Iterable, List, Optional from django.conf import settings as django_settings from django.utils.module_loading import import_string @@ -36,6 +36,23 @@ "KEY_PREFIX": "dgs", "TTL": 86400, }, + "LANGGRAPH": { + "ENABLED": False, + "SEARCH_GRAPH": "django_graph_search.langgraph_agent.build_search_graph", + "USE_FOR_SIMILAR": False, + "QUERY_EXPANSION": False, + "RERANKING": False, + "MAX_EXPANDED_QUERIES": 3, + "RERANK_TOP_K": 20, + "TIMEOUT_SECONDS": 15, + "MAX_QUERY_LENGTH": 1024, + "FALLBACK_ON_ERROR": True, + "LLM": { + "BACKEND": None, + "MODEL": None, + "OPTIONS": {}, + }, + }, } @@ -69,6 +86,28 @@ class CacheConfig: ttl: int = 86400 +@dataclass(frozen=True) +class LLMConfig: + backend: Optional[str] = None + model: 
def _build_langgraph_config(payload: Dict[str, Any]) -> LangGraphConfig:
    """Build a validated ``LangGraphConfig`` from the ``LANGGRAPH`` settings.

    ``payload`` is merged over ``DEFAULTS["LANGGRAPH"]`` with ``_merge_dicts``
    and each value is coerced to the type of the matching config field.

    Args:
        payload: The user-supplied ``LANGGRAPH`` section; may be empty.

    Returns:
        The populated ``LangGraphConfig``.

    Raises:
        ConfigurationError: If ``payload`` or its ``LLM`` entry is not a dict,
            or if one of the numeric limits is not an integer ``>= 1``.
    """
    if not isinstance(payload, dict):
        raise ConfigurationError("LANGGRAPH must be a dict.")
    merged = _merge_dicts(DEFAULTS["LANGGRAPH"], payload)

    def _positive_int(key: str, default: int) -> int:
        """Return ``merged[key]`` as an int ``>= 1`` or raise ConfigurationError."""
        raw = merged.get(key, default)
        try:
            value = int(raw)
        except (TypeError, ValueError) as exc:
            # Report bad values (e.g. "abc" or None) as a configuration
            # problem instead of leaking a bare ValueError/TypeError.
            raise ConfigurationError(
                f"LANGGRAPH.{key} must be an integer, got {raw!r}."
            ) from exc
        if value < 1:
            raise ConfigurationError(f"LANGGRAPH.{key} must be >= 1.")
        return value

    llm_payload = merged.get("LLM") or {}
    if not isinstance(llm_payload, dict):
        raise ConfigurationError("LANGGRAPH.LLM must be a dict.")
    llm_cfg = LLMConfig(
        backend=llm_payload.get("BACKEND"),
        model=llm_payload.get("MODEL"),
        options=llm_payload.get("OPTIONS", {}) or {},
    )

    # Validated in the same order as before so the first reported error is
    # unchanged for configurations with several invalid values.
    max_expanded = _positive_int("MAX_EXPANDED_QUERIES", 3)
    rerank_top_k = _positive_int("RERANK_TOP_K", 20)
    timeout_seconds = _positive_int("TIMEOUT_SECONDS", 15)
    max_query_length = _positive_int("MAX_QUERY_LENGTH", 1024)

    return LangGraphConfig(
        enabled=bool(merged.get("ENABLED", False)),
        # An empty/None SEARCH_GRAPH falls back to the built-in graph factory.
        search_graph=str(
            merged.get("SEARCH_GRAPH")
            or "django_graph_search.langgraph_agent.build_search_graph"
        ),
        use_for_similar=bool(merged.get("USE_FOR_SIMILAR", False)),
        query_expansion=bool(merged.get("QUERY_EXPANSION", False)),
        reranking=bool(merged.get("RERANKING", False)),
        max_expanded_queries=max_expanded,
        rerank_top_k=rerank_top_k,
        timeout_seconds=timeout_seconds,
        max_query_length=max_query_length,
        fallback_on_error=bool(merged.get("FALLBACK_ON_ERROR", True)),
        llm=llm_cfg,
    )
@pytest.fixture
def graph_search_settings():
    """Yield a callable that installs a ``GRAPH_SEARCH`` payload for one test.

    The library reads its configuration via ``get_settings``, which is
    ``lru_cache``-d, so the cache is cleared before the test, after every
    payload change, and again on teardown to keep tests isolated.

    The yielded callable takes the payload dict, assigns it to
    ``django_settings.GRAPH_SEARCH`` and returns the freshly built settings
    object from ``get_settings()``.
    """
    # A private sentinel tells "attribute was absent" apart from "attribute
    # was explicitly None", so an original None value is restored rather
    # than deleted on teardown.
    missing = object()
    original = getattr(django_settings, "GRAPH_SEARCH", missing)
    get_settings.cache_clear()

    def _apply(payload):
        django_settings.GRAPH_SEARCH = payload
        get_settings.cache_clear()
        return get_settings()

    yield _apply

    if original is missing:
        if hasattr(django_settings, "GRAPH_SEARCH"):
            delattr(django_settings, "GRAPH_SEARCH")
    else:
        django_settings.GRAPH_SEARCH = original
    get_settings.cache_clear()
def test_settings_validate_negative_values(graph_search_settings):
    """Non-positive ``MAX_EXPANDED_QUERIES`` values are rejected at load time."""
    # Imported locally so this test does not depend on the module-level
    # import list; the settings module raises this name directly.
    from django_graph_search.settings import ConfigurationError

    # Cover both the zero boundary and a genuinely negative value, and pin
    # the exception type and message so unrelated failures (import errors,
    # typos) cannot make the test pass by accident.
    for invalid in (0, -1):
        with pytest.raises(ConfigurationError, match="MAX_EXPANDED_QUERIES"):
            graph_search_settings({
                "MODELS": [],
                "LANGGRAPH": {"MAX_EXPANDED_QUERIES": invalid},
            })
def test_vector_search_node_merges_and_dedupes():
    """Hits from several expanded queries are merged into one entry per id."""
    embed = StubEmbeddingBackend(["a", "b"])
    store = StubVectorStore({
        "a": [
            FakeHit("test_app.Product:1", 0.4, {"model": "test_app.Product", "pk": 1}),
            FakeHit("test_app.Product:2", 0.2, {"model": "test_app.Product", "pk": 2}),
        ],
        "b": [
            FakeHit("test_app.Product:1", 0.9, {"model": "test_app.Product", "pk": 1}),
            FakeHit("test_app.Product:3", 0.1, {"model": "test_app.Product", "pk": 3}),
        ],
    })
    state: SearchState = {
        "expanded_queries": ["a", "b"],
        "limit": 10,
        "models": None,
    }
    out = vector_search_node(state, embedding_backend=embed, vector_store=store)
    hits = out["raw_results"]
    ids = [hit.id for hit in hits]
    # Best score per id wins; ordering is by score desc.
    assert ids[0] == "test_app.Product:1"
    assert set(ids) == {
        "test_app.Product:1",
        "test_app.Product:2",
        "test_app.Product:3",
    }
    # Product:1 appears under both queries (0.4 and 0.9); the higher one
    # must be the score that is kept.
    assert hits[0].score == pytest.approx(0.9)
+ assert [c.id for c in out["reranked_results"][:2]] == ["b", "a"] + assert out["reranked_results"][2].id == "c" + + +def test_postprocess_node_applies_limit(): + state: SearchState = { + "merged_results": [FakeHit(str(i), float(i), {"model": "m", "pk": i}) for i in range(5)], + "limit": 2, + } + out = postprocess_results_node(state) + assert len(out["final_results"]) == 2 + + +# --------------------------------------------------------------------------- +# Searcher integration (uses the in-tree fallback graph runner) +# --------------------------------------------------------------------------- + + +def _make_searcher_settings(extra=None): + payload = { + "MODELS": [], + "VECTOR_STORE": {"BACKEND": "django_graph_search.backends.ChromaDBBackend"}, + "EMBEDDINGS": { + "default": { + "BACKEND": "tests.dummy_embedding_backend.DummyEmbeddingBackend", + "MODEL_NAME": "x", + } + }, + } + if extra: + payload.update(extra) + return payload + + +@pytest.mark.django_db +def test_searcher_disabled_uses_legacy_path(graph_search_settings): + graph_search_settings(_make_searcher_settings()) + embed = StubEmbeddingBackend(["foo"]) + store = StubVectorStore({"foo": [ + FakeHit("test_app.Product:1", 0.5, {"model": "test_app.Product", "pk": 1}), + ]}) + searcher = Searcher(vector_store=store, embedding_backend=embed) + results = searcher.search("foo") + assert results and results[0]["model"] == "test_app.Product" + + +@pytest.mark.django_db +def test_searcher_enabled_returns_same_shape(graph_search_settings): + graph_search_settings(_make_searcher_settings({"LANGGRAPH": {"ENABLED": True}})) + embed = StubEmbeddingBackend(["foo"]) + store = StubVectorStore({"foo": [ + FakeHit("test_app.Product:1", 0.5, {"model": "test_app.Product", "pk": 1}), + ]}) + searcher = Searcher(vector_store=store, embedding_backend=embed) + results = searcher.search("foo", limit=5) + assert results + assert set(results[0].keys()) >= {"model", "pk", "score"} + + +@pytest.mark.django_db +def 
def boom_graph_factory(*args, **kwargs):
    """Graph factory stand-in that always fails.

    Accepts and ignores any arguments, then raises ``RuntimeError("boom")``.
    It is referenced by dotted path through the ``SEARCH_GRAPH`` setting in
    the fallback test in this module.
    """
    failure = RuntimeError("boom")
    raise failure