diff --git a/README.md b/README.md index 4cf8774..79648cb 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,9 @@ [![MCP](https://img.shields.io/badge/MCP-1.2+-8A2BE2)](https://modelcontextprotocol.io) [![Discord](https://img.shields.io/badge/Discord-join-5865F2?logo=discord&logoColor=white)](https://discord.gg/McHmxUeS) +> ⭐ **If `ks-mcp` saves you a day of wiring up retrieval, please [star the repo](https://github.com/knowledgestack/ks-mcp/stargazers) — it's the single best signal we use to prioritize the [roadmap](#roadmap).** +> Got a tool you wish existed? [Open a feature request](https://github.com/knowledgestack/ks-mcp/issues/new?template=feature_request.yml). Want a working example? See the [`ks-cookbook`](https://github.com/knowledgestack/ks-cookbook). + --- ## Table of contents @@ -32,7 +35,9 @@ - [pydantic-ai](#pydantic-ai) - [LangGraph](#langgraph) - [OpenAI Agents SDK](#openai-agents-sdk) -- [Tools (v1 — read-only)](#tools-v1--read-only) +- [Tools](#tools) +- [How the tools fit together](#how-the-tools-fit-together) +- [Examples & cookbooks](#examples--cookbooks) - [Transports](#transports) - [Security model](#security-model) - [Diagnostics](#diagnostics) @@ -56,7 +61,7 @@ Most agent frameworks ship their own "retrieval toolbox" the moment you need to **Key properties:** -- **Read-only by design.** v1 deliberately exposes no mutating tools. Safe to connect to production knowledge bases. +- **Mostly read-only.** Every tool is read-only except `ask`, which posts a user message to a (newly-created or reused) thread so the KS agent can stream a grounded answer back. There is no ingest / delete surface in v1. - **Tenant-scoped.** Every call is authenticated with a per-user API key; nothing crosses tenant boundaries. - **Grounded.** Every search result and `read` payload returns stable chunk IDs + path parts you can cite. - **Two transports.** Local stdio for desktop agents; Streamable HTTP for remote / multi-agent deployments. 
@@ -212,10 +217,12 @@ agent = Agent(name="Research", mcp_servers=[server]) | Tool | Description | | --- | --- | +| `ask` | One-shot grounded Q&A: dispatches to the KS agent, streams the assistant reply, and returns assembled text + citations. | | `search_knowledge` | Semantic (dense-vector) chunk search over the tenant corpus. | | `search_keyword` | BM25 chunk search for exact terminology and identifiers. | -| `read` | Read a folder / document / section / chunk by path or ID. | +| `read` | Read a folder / document / section / chunk by `path_part_id` (also accepts a `chunk_id` directly). | | `read_around` | Fetch the N chunks before and after an anchor chunk for context expansion. | +| `cite` | Build a structured citation (document, path, page, snippet, `[chunk:UUID]` tag) for one chunk. | | `list_contents` | List children of a folder (like `ls`). | | `find` | Fuzzy-match a path-part by name when you don't know the exact path. | | `get_info` | Path-part metadata + ancestry breadcrumb, for citations. | @@ -242,6 +249,68 @@ agent = Agent(name="Research", mcp_servers=[server]) Writes (ingest / delete / generate) are intentionally **not** exposed in Phase 1 or 2. See the [Roadmap](#roadmap) for the plan around admin-scoped write tools. +## How the tools fit together + +You have **two paths** to a grounded answer. Pick the one that fits the agent +you're building. + +```text + ┌────────────── Quick path: one-shot grounded Q&A ──────────────┐ + │ │ + user q. ──►│ ask(question, [thread_id]) │──► answer + citations + │ (KS agent does the retrieval + drafting; you just ship it.) 
│ + └────────────────────────────────────────────────────────────────┘ + + ┌────────────── Custom path: roll your own loop ────────────────┐ + │ │ + │ search_knowledge / search_keyword │ + │ │ │ + │ ▼ chunk_id, materialized_path │ + │ read_around(chunk_id) · read(chunk_id|pp_id) · view_chunk_image + │ │ │ + │ ▼ │ + │ cite(chunk_id) → [chunk:UUID] tag + structured footnote │ + │ │ │ + │ ▼ │ + │ you assemble the answer │ + └────────────────────────────────────────────────────────────────┘ +``` + +`ask` is the right choice when you want one tool call to do the whole job. +The custom path is right when you need to weave multiple chunks across +documents, when you want to control the prompt, or when you're building a +multi-step agent that interleaves retrieval with other tools. + +Side tools — `list_contents`, `find`, `get_info` — exist for navigation when +the user asks about a specific document by name or wants you to walk a folder +tree. `trace_chunk_lineage` and `compare_versions` answer "where did this +evidence come from?" once you already have a chunk in hand. + +**Identifier cheat sheet** + +| Field | Source | Use it with | +| --- | --- | --- | +| `chunk_id` | `search_*` hits, `read` output, neighbour `[chunk:UUID]` tags | `cite`, `read_around`, `view_chunk_image`, `read` (fallback path) | +| `path_part_id` | `list_contents`, `find`, `get_info`, search hits | `read`, `get_info`, `list_contents`, search filters | +| `materialized_path` | every chunk / path-part response | display only — never use as an id | + +`chunk_id` and `path_part_id` look identical (both are UUIDs) but are +**different objects**. When in doubt, pass it to `read` — it accepts either. + +## Examples & cookbooks + +End-to-end, citation-grounded examples live in **[`ks-cookbook`](https://github.com/knowledgestack/ks-cookbook)** — every recipe drives this MCP server (stdio plumbing, real `[chunk:UUID]` citations) from a working agent. 
The cookbook organises recipes by domain; some categories you'll find: + +- **Sales / RevOps** — account research, ICP matching, deal-loss retros, churn risk evidence. +- **Legal / Privacy** — NDA review, DPA gap checks, clause extraction, data-subject request responder. +- **Healthcare** — discharge summary rewrite, drug-interaction checker, audit-defensible HCC coder. +- **Finance & risk** — Basel III risk weighting, AML/SAR narrative drafting, cash-flow anomaly detection. +- **Engineering ops** — ADR drafter, changelog from commits, API deprecation notices, change-monitor → PR. + +Browse the full list in [`recipes/INDEX.md`](https://github.com/knowledgestack/ks-cookbook/blob/main/recipes/INDEX.md) (and the longer-form **flagships/** directory for multi-step agents). + +If you build something interesting on top of `ks-mcp`, please [open a PR against `ks-cookbook`](https://github.com/knowledgestack/ks-cookbook/pulls) — we feature community recipes on the cookbook front page. + ## Transports | Transport | When to use | Command | @@ -293,14 +362,17 @@ src/ks_mcp/ ├── errors.py # typed error mapping └── tools/ ├── search.py # search_knowledge, search_keyword - ├── read.py # read, read_around - ├── browse.py # list_contents, find, get_info, view_chunk_image - └── org.py # get_organization_info, get_current_datetime + ├── read.py # read, read_around, view_chunk_image + ├── cite.py # cite (structured citation builder) + ├── ask.py # ask (one-shot agent Q&A over SSE) + ├── browse.py # list_contents, find, get_info + ├── org.py # get_organization_info, get_current_datetime + └── provenance.py # trace_chunk_lineage, compare_versions ``` ## Roadmap -See [ROADMAP.md](ROADMAP.md) and the [public issue tracker](https://github.com/knowledgestack/ks-mcp/issues) for everything on deck. Highlights for the next few releases: +See [ROADMAP.md](ROADMAP.md) and the [public issue tracker](https://github.com/knowledgestack/ks-mcp/issues) for everything on deck. 
**We prioritize what users thumbs-up** — if a milestone matters to you, react on the issue. - **v0.2** — OAuth 2.1 device flow auth, resource templates for folders/documents, streaming partial results, prompt library. - **v0.3** — admin-scoped **write tools** behind an explicit opt-in flag (`--allow-write`): ingest, delete, re-embed. @@ -308,11 +380,15 @@ See [ROADMAP.md](ROADMAP.md) and the [public issue tracker](https://github.com/k - **v0.5** — hybrid search (dense + BM25 fusion) tool, and a `summarize_document` convenience tool. - **v1.0** — stable tool surface, semver guarantees, registry listing on [github.com/mcp](https://github.com/mcp). -Want to influence it? Thumbs-up the issues you care about, or open a [feature request](https://github.com/knowledgestack/ks-mcp/issues/new?template=feature_request.yml). +Three ways to influence the roadmap: + +1. ⭐ **[Star the repo](https://github.com/knowledgestack/ks-mcp/stargazers)** — stars are how we justify investment in this surface. +2. 👍 **Thumbs-up issues** in the [tracker](https://github.com/knowledgestack/ks-mcp/issues) — we sort by reactions when picking the next milestone. +3. ✨ **[Open a feature request](https://github.com/knowledgestack/ks-mcp/issues/new?template=feature_request.yml)** — concrete use cases beat abstract wishlists. ## Related repos -- **[ks-cookbook](https://github.com/knowledgestack/ks-cookbook)** — 32 production-style agent flagships built on this server. +- **[ks-cookbook](https://github.com/knowledgestack/ks-cookbook)** — production-style agent flagships built on this server (start here for working code). - **[ks-sdk-python](https://github.com/knowledgestack/ks-sdk-python)** — Python SDK (`ksapi` on PyPI) for admin / write operations. - **[ks-sdk-ts](https://github.com/knowledgestack/ks-sdk-ts)** — TypeScript SDK (`@knowledge-stack/ksapi` on npm). - **[ks-docs](https://github.com/knowledgestack/ks-docs)** — central developer docs (Mintlify → docs.knowledgestack.ai). 
@@ -323,6 +399,11 @@ Issues and PRs welcome. Please read [SECURITY.md](SECURITY.md) before reporting Development happens in the open on `main`; feature branches land via PR with CI (pytest + ruff + pyright) required to pass. +**Two quick ways to help, even if you can't open a PR:** + +- ⭐ **Star** the repo — it directly shapes our investment. +- 💬 Drop a note on [Discord](https://discord.gg/McHmxUeS) telling us what you're building. We frequently turn user stories into cookbook recipes. + ## License MIT — see [LICENSE](LICENSE). diff --git a/ROADMAP.md b/ROADMAP.md index 11d9a86..981b3f1 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1,6 +1,8 @@ # Roadmap -`ks-mcp` is in active development. This is our public plan — each item below is tracked as a GitHub issue, grouped into a milestone. Thumbs-up the ones that matter most to you; we prioritize community signal. +`ks-mcp` is in active development. This is our public plan — each item below is tracked as a GitHub issue, grouped into a milestone. + +> 👍 **Thumbs-up the issues that matter most to you.** We sort milestones by community signal — your reaction directly moves things up the queue. ⭐ Starring the repo helps too. > **Legend:** 🟢 shipped · 🟡 in progress · ⚪ planned @@ -8,11 +10,14 @@ Read-only tool surface over the Knowledge Stack API. Shipped. +- 🟢 `ask` — one-shot grounded Q&A via the KS agent (SSE assembled into a single result) - 🟢 `search_knowledge`, `search_keyword` - 🟢 `read`, `read_around` +- 🟢 `cite` — structured citations with page-aware footnotes - 🟢 `list_contents`, `find`, `get_info` - 🟢 `view_chunk_image` - 🟢 `get_organization_info`, `get_current_datetime` +- 🟢 `trace_chunk_lineage`, `compare_versions` (provenance) - 🟢 stdio + Streamable HTTP transports ## v0.2 — Registry & auth 🟡 (target: Q2 2026) diff --git a/src/ks_mcp/client.py b/src/ks_mcp/client.py index 37fa957..2d7be0e 100644 --- a/src/ks_mcp/client.py +++ b/src/ks_mcp/client.py @@ -5,7 +5,6 @@ ``KS_BASE_URL`` from the environment. 
""" - import os from functools import lru_cache @@ -15,9 +14,7 @@ def _env(key: str, *, required: bool = True, default: str | None = None) -> str: value = os.environ.get(key, default) if required and not value: - raise RuntimeError( - f"{key} is not set. Export it or add it to your MCP client config." - ) + raise RuntimeError(f"{key} is not set. Export it or add it to your MCP client config.") return value or "" diff --git a/src/ks_mcp/errors.py b/src/ks_mcp/errors.py index 1030fee..b361f23 100644 --- a/src/ks_mcp/errors.py +++ b/src/ks_mcp/errors.py @@ -1,6 +1,5 @@ """Map ``ksapi.ApiException`` onto MCP-friendly responses.""" - import ksapi from mcp import McpError from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, ErrorData @@ -25,10 +24,7 @@ def rest_to_mcp(exc: ksapi.ApiException) -> McpError: msg = f"Forbidden for this path (403). {snippet}" return McpError(ErrorData(code=INVALID_PARAMS, message=msg)) if status >= 500: - msg = ( - f"KS backend error ({status}). Transient; retry later, do not auto-loop. " - f"{snippet}" - ) + msg = f"KS backend error ({status}). Transient; retry later, do not auto-loop. {snippet}" return McpError(ErrorData(code=INTERNAL_ERROR, message=msg)) return McpError( diff --git a/src/ks_mcp/schema.py b/src/ks_mcp/schema.py index 4368ffa..ccdef9a 100644 --- a/src/ks_mcp/schema.py +++ b/src/ks_mcp/schema.py @@ -5,7 +5,6 @@ per-argument documentation. Good descriptions here == good tool use. """ - from enum import Enum from uuid import UUID @@ -26,14 +25,33 @@ class ChunkType(str, Enum): class ChunkHit(BaseModel): - chunk_id: UUID + chunk_id: UUID = Field( + ..., + description=( + "Stable id for this chunk. Use it with `cite`, `read_around`, `view_chunk_image`, " + "or `read` (which now also accepts a chunk_id). NOT the same as `path_part_id`." 
+ ), + ) document_name: str = Field(..., description="Name of the owning document.") text: str = Field(..., description="Raw text of the chunk.") score: float | None = Field( - None, description="Relevance score from the search backend (higher = better)." + default=None, description="Relevance score from the search backend (higher = better)." ) chunk_type: ChunkType = Field(default=ChunkType.TEXT) - path_part_id: UUID | None = None + path_part_id: UUID | None = Field( + default=None, + description=( + "PDO id of the chunk node in the path tree. Pass to `get_info` for ancestry; " + "do NOT confuse with `chunk_id` — they are different UUIDs." + ), + ) + materialized_path: str | None = Field( + default=None, + description=( + "Full root-to-leaf path of the chunk (e.g. `Folder/SubFolder/Document/Section/...`). " + "Use this when displaying citations or grouping hits by document." + ), + ) class PathPartInfo(BaseModel): @@ -51,10 +69,13 @@ class PathPartInfo(BaseModel): class SearchInput(BaseModel): - query: str = Field(..., min_length=1, max_length=4_000, - description="Natural-language query for semantic search, or keyword phrase for BM25.") - top_k: int = Field(default=5, ge=1, le=50, - description="Maximum number of chunks to return.") + query: str = Field( + ..., + min_length=1, + max_length=4_000, + description="Natural-language query for semantic search, or keyword phrase for BM25.", + ) + top_k: int = Field(default=5, ge=1, le=50, description="Maximum number of chunks to return.") parent_path_part_ids: list[UUID] | None = Field( default=None, description="Restrict search to descendants of these path-parts. 
None = whole tenant.", @@ -63,22 +84,20 @@ class SearchInput(BaseModel): default=None, description="Only include chunks whose owning document carries all these tags.", ) - distinct_files: bool = Field( - default=False, - description="If True, return at most one chunk per source document.", - ) class ReadInput(BaseModel): path_part_id: UUID = Field( ..., description=( - "Any PDO identifier — folder, document, section, or chunk. The tool " - "dispatches to the appropriate read path." + "Any PDO identifier — folder, document, section, or chunk path-part. The tool " + "dispatches on part type, and falls back to fetching as a chunk_id on 404." ), ) max_chars: int = Field( - default=4_000, ge=100, le=50_000, + default=4_000, + ge=100, + le=50_000, description="Truncate returned text to this many characters.", ) @@ -86,7 +105,9 @@ class ReadInput(BaseModel): class ReadAroundInput(BaseModel): chunk_id: UUID = Field(..., description="Anchor chunk.") radius: int = Field( - default=2, ge=0, le=10, + default=2, + ge=0, + le=10, description="Number of chunks before AND after the anchor to include.", ) @@ -99,8 +120,12 @@ class ListContentsInput(BaseModel): class FindInput(BaseModel): - query: str = Field(..., min_length=1, max_length=256, - description="Fuzzy-match substring of the path-part's name.") + query: str = Field( + ..., + min_length=1, + max_length=256, + description="Fuzzy-match substring of the path-part's name.", + ) parent_path_part_ids: list[UUID] | None = Field( default=None, description="Restrict to descendants of these path-parts.", @@ -135,15 +160,85 @@ class PathPartAncestry(BaseModel): class OrganizationInfo(BaseModel): tenant_id: UUID name: str - default_language: str = Field( - ..., description="ISO-639 code, e.g. 'en'." - ) + default_language: str = Field(..., description="ISO-639 code, e.g. 'en'.") timezone: str = Field(..., description="IANA tz, e.g. 
'America/New_York'.") class CurrentDateTime(BaseModel): iso_utc: str = Field(..., description="Current time in UTC, ISO-8601.") - iso_local: str = Field( - ..., description="Current time in the tenant's timezone, ISO-8601." - ) + iso_local: str = Field(..., description="Current time in the tenant's timezone, ISO-8601.") timezone: str = Field(..., description="IANA tz used for ``iso_local``.") + + +class Citation(BaseModel): + """Compact citation payload for a single chunk. + + Designed to be appended verbatim to an agent's prose answer. The ``tag`` + field is a stable inline reference (``[chunk:UUID]``); ``snippet`` is a + short excerpt suitable for tooltips or footnotes; ``materialized_path`` + + ``page_number`` give a human-readable source location. + """ + + chunk_id: UUID + document_name: str = Field(..., description="Name of the owning document.") + materialized_path: str | None = Field( + default=None, + description="Root-to-leaf path of the chunk (e.g. ``Handbook/Onboarding/Section 2/...``).", + ) + page_number: int | None = Field( + default=None, + description=( + "Page number of the chunk's nearest SECTION ancestor, when available. " + "May be None for non-paginated documents (web pages, plain text)." + ), + ) + snippet: str = Field( + ..., description="Up to ~240 chars of the chunk text, suitable for a footnote." + ) + tag: str = Field( + ..., + description=( + "Inline reference token: ``[chunk:UUID]``. Append to the sentence in the agent's " + "answer that the chunk supports." + ), + ) + + +class AskCitation(BaseModel): + """Citation surfaced inline by the KS agent during an ``ask`` call. + + Lighter than ``Citation`` — the agent emits these directly, so we expose + only what the streaming protocol provides without an extra round-trip. 
+ """ + + chunk_id: UUID + quote: str = Field(..., description="The quoted text from the chunk.") + document_id: UUID | None = None + document_name: str | None = None + materialized_path: str | None = None + page_number: int | None = None + + +class AskResult(BaseModel): + """Final assistant answer assembled from a thread streaming run.""" + + answer: str = Field(..., description="The agent's final assistant message text.") + citations: list[AskCitation] = Field( + default_factory=list, + description="Citations the agent emitted while answering. May be empty.", + ) + thread_id: UUID = Field( + ..., description="Thread the conversation lives on (reuse for follow-ups)." + ) + message_id: UUID | None = Field( + default=None, + description="Assistant message id within the thread. This is NOT a chunk id — do not pass it to `read_around` or `cite`; use the ``chunk_id`` of an entry in ``citations`` instead.", + ) + workflow_id: str | None = Field( + default=None, + description="Underlying agent workflow id; surfaced for debugging and audit.", + ) + is_error: bool = Field( + default=False, + description="True if the agent hit a terminal error mid-stream; ``answer`` then contains any text streamed before the error, or the error message if nothing was streamed.", + ) diff --git a/src/ks_mcp/server.py b/src/ks_mcp/server.py index 6534ae6..a076b61 100644 --- a/src/ks_mcp/server.py +++ b/src/ks_mcp/server.py @@ -13,33 +13,46 @@ KS_BASE_URL (optional) Override the KS API host (default: production). """ - import argparse from mcp.server.fastmcp import FastMCP -from ks_mcp.tools import browse, org, provenance, read, search +from ks_mcp.tools import ask, browse, cite, org, provenance, read, search def build_server(host: str = "127.0.0.1", port: int = 8765) -> FastMCP: mcp = FastMCP( name="knowledgestack", instructions=( - "Knowledge Stack: ground every answer in the tenant's knowledge base. " - "Prefer `search_knowledge` for conceptual questions and `search_keyword` " - "for exact terms; use `read`/`read_around` to pull full text with citations." 
+ "Knowledge Stack: ground every answer in the tenant's knowledge base.\n" + "Two ways to use this server:\n" + " • Quick path — call `ask(question)` for a one-shot grounded answer " + "from the KS agent (streamed assistant reply assembled into one result, " + "with inline citations).\n" + " • Custom path — `search_knowledge` (concepts) or `search_keyword` " + "(exact terms) → `read_around` / `read` for context → `cite` once per " + "chunk used. Append the returned `[chunk:UUID]` tag to the supporting " + "sentence in your answer.\n" + "Use `list_contents` / `find` / `get_info` to navigate the folder tree, " + "and `view_chunk_image` for IMAGE-type chunks. The only state-changing " + "tool is `ask`, which creates a thread and assistant message; everything " + "else is read-only." ), host=host, port=port, ) - for module in (search, read, browse, org, provenance): + for module in (search, read, browse, cite, ask, org, provenance): module.register(mcp) return mcp def main() -> None: - parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("--http", action="store_true", help="Serve over Streamable HTTP instead of stdio.") + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument( + "--http", action="store_true", help="Serve over Streamable HTTP instead of stdio." 
+ ) parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", type=int, default=8765) args = parser.parse_args() diff --git a/src/ks_mcp/tools/ask.py b/src/ks_mcp/tools/ask.py new file mode 100644 index 0000000..3401a38 --- /dev/null +++ b/src/ks_mcp/tools/ask.py @@ -0,0 +1,235 @@ +"""Ask the Knowledge Stack agent a question and return the final answer.""" + +import json +import os +from typing import Annotated, Any +from uuid import UUID + +import httpx +import ksapi +from ksapi.api.threads_api import ThreadsApi +from mcp.server.fastmcp import FastMCP +from pydantic import Field + +from ks_mcp.client import get_api_client +from ks_mcp.errors import rest_to_mcp +from ks_mcp.schema import AskCitation, AskResult + +_DEFAULT_TIMEOUT_S = 120.0 +_SSE_READ_TIMEOUT_S = 180.0 + + +def _ensure_thread(client: ksapi.ApiClient, thread_id: UUID | None, question: str) -> UUID: + """Return an existing thread id, or create a new thread auto-titled by ``question``.""" + if thread_id is not None: + return thread_id + create_req = ksapi.CreateThreadRequest(message_for_title=question[:4000]) + try: + thread: Any = ThreadsApi(client).create_thread(create_thread_request=create_req) + except ksapi.ApiException as exc: + raise rest_to_mcp(exc) from exc + return thread.id + + +def _send_user_message( + client: ksapi.ApiClient, + thread_id: UUID, + question: str, +) -> str: + """POST /threads/{id}/user_message → returns workflow_id (202).""" + req = ksapi.UserMessageRequest(input_text=question) + try: + resp: Any = ThreadsApi(client).send_user_message( + thread_id=thread_id, + user_message_request=req, + ) + except ksapi.ApiException as exc: + raise rest_to_mcp(exc) from exc + return str(getattr(resp, "workflow_id", "") or "") + + +def _parse_sse_block(block: str) -> tuple[str | None, str]: + """Pull (event, data) out of one SSE event block. 
Returns (None, "") on garbage.""" + event: str | None = None + data_lines: list[str] = [] + for raw in block.splitlines(): + if not raw or raw.startswith(":"): + continue + if ":" not in raw: + continue + field, _, value = raw.partition(":") + value = value.lstrip(" ") + if field == "event": + event = value + elif field == "data": + data_lines.append(value) + return event, "\n".join(data_lines) + + +def _stream_answer( + base_url: str, + api_key: str, + thread_id: UUID, + timeout_s: float, +) -> AskResult: + """Open the SSE stream, accumulate text deltas, return the assembled answer.""" + parts: list[str] = [] + citations: list[AskCitation] = [] + message_id: UUID | None = None + is_error = False + error_text = "" + + headers = { + "Authorization": f"Bearer {api_key}", + "Accept": "text/event-stream", + } + url = f"{base_url.rstrip('/')}/v1/threads/{thread_id}/stream" + + with httpx.Client(timeout=httpx.Timeout(timeout_s, read=_SSE_READ_TIMEOUT_S)) as http: + with http.stream("GET", url, headers=headers) as resp: + if resp.status_code != 200: + resp.read() + raise RuntimeError( + f"Stream returned HTTP {resp.status_code}: " + f"{resp.text[:300] if resp.text else ''}" + ) + buffer = "" + for chunk in resp.iter_text(): + if not chunk: + continue + buffer += chunk + while "\n\n" in buffer: + block, buffer = buffer.split("\n\n", 1) + event, data = _parse_sse_block(block) + if event is None and data == "[DONE]": + return _build_result( + parts, + citations, + thread_id, + message_id, + is_error, + error_text, + ) + if not event: + continue + payload: dict[str, Any] + try: + payload = json.loads(data) if data else {} + except json.JSONDecodeError: + continue + if event == "message_start": + raw_id = payload.get("id") + if raw_id: + try: + message_id = UUID(str(raw_id)) + except ValueError: + pass + elif event == "text_delta": + delta = payload.get("delta") + if isinstance(delta, str): + parts.append(delta) + elif event == "citations": + for raw in 
payload.get("citations") or []: + try: + citations.append(_to_ask_citation(raw)) + except (KeyError, ValueError, TypeError): + continue + elif event == "error": + is_error = True + error_text = str(payload.get("error", "")) or "agent error" + elif event == "message_end": + return _build_result( + parts, + citations, + thread_id, + message_id, + is_error, + error_text, + ) + return _build_result(parts, citations, thread_id, message_id, is_error, error_text) + + +def _to_ask_citation(raw: dict[str, Any]) -> AskCitation: + return AskCitation( + chunk_id=UUID(str(raw["chunk_id"])), + quote=str(raw.get("quote", "")), + document_id=UUID(str(raw["document_id"])) if raw.get("document_id") else None, + document_name=raw.get("document_name"), + materialized_path=raw.get("materialized_path"), + page_number=raw.get("page_number"), + ) + + +def _build_result( + parts: list[str], + citations: list[AskCitation], + thread_id: UUID, + message_id: UUID | None, + is_error: bool, + error_text: str, +) -> AskResult: + answer = "".join(parts).strip() + if is_error and not answer: + answer = f"(agent error) {error_text}" + return AskResult( + answer=answer or "(empty answer)", + citations=citations, + thread_id=thread_id, + message_id=message_id, + is_error=is_error, + ) + + +def register(mcp: FastMCP) -> None: + @mcp.tool() + def ask( + question: Annotated[ + str, + Field( + description="Natural-language question to send to the KS agent.", + min_length=1, + max_length=8000, + ), + ], + thread_id: Annotated[ + UUID | None, + Field( + description=( + "Reuse an existing thread for multi-turn follow-ups. " + "Omit to start a fresh thread, auto-titled from ``question``." + ), + ), + ] = None, + timeout_s: Annotated[ + float, + Field( + description="Connect/write/pool timeout in seconds for the streaming request; the per-read timeout is fixed at 180 s. A timeout raises an error rather than returning a partial answer.", + ge=10.0, + le=600.0, + ), + ] = _DEFAULT_TIMEOUT_S, + ) -> AskResult: + """Ask the Knowledge Stack agent a question and return the final answer. 
+ + Wraps the backend's two-step ask flow (POST user message → SSE stream) + into a single synchronous tool call: assembles the streamed text, + captures any inline citations, and returns once the agent emits + ``message_end``. Use this when you want one grounded answer rather + than running your own retrieval loop. + + Returns ``AskResult(answer, citations[], thread_id, message_id, ...)``. + Pass the same ``thread_id`` back on a follow-up call to continue the + conversation. For raw retrieval (without an LLM), use + ``search_knowledge`` / ``search_keyword`` + ``cite`` instead. + """ + client = get_api_client() + thread = _ensure_thread(client, thread_id, question) + workflow_id = _send_user_message(client, thread, question) + + api_key = os.environ.get("KS_API_KEY", "") + base_url = os.environ.get("KS_BASE_URL", "https://api.knowledgestack.ai") + + result = _stream_answer(base_url, api_key, thread, timeout_s) + if workflow_id and not result.workflow_id: + result = result.model_copy(update={"workflow_id": workflow_id}) + return result diff --git a/src/ks_mcp/tools/browse.py b/src/ks_mcp/tools/browse.py index 6c6d07b..984d280 100644 --- a/src/ks_mcp/tools/browse.py +++ b/src/ks_mcp/tools/browse.py @@ -1,6 +1,5 @@ """Folder-tree navigation: list children, fuzzy-find by name, introspect a node.""" - from typing import Annotated, Any from uuid import UUID @@ -17,9 +16,7 @@ def _pp_info(pp: Any) -> PathPartInfo | None: inner = getattr(pp, "actual_instance", None) or pp - path_part_id = ( - getattr(inner, "path_part_id", None) or getattr(inner, "id", None) - ) + path_part_id = getattr(inner, "path_part_id", None) or getattr(inner, "id", None) if path_part_id is None: return None return PathPartInfo( @@ -30,9 +27,9 @@ def _pp_info(pp: Any) -> PathPartInfo | None: ) -def _filter_pp_infos(items: list[Any]) -> list[PathPartInfo]: +def _filter_pp_infos(items: Any) -> list[PathPartInfo]: out: list[PathPartInfo] = [] - for i in items: + for i in items or []: info = _pp_info(i) 
if info is not None: out.append(info) @@ -54,7 +51,10 @@ def _resolve_folder_id(client: Any, folder_id: UUID) -> UUID: return folder_id raise rest_to_mcp(exc) from exc - if str(getattr(pp, "part_type", "")) != "FOLDER": + part_type = str(getattr(pp, "part_type", "")) + if part_type.startswith("PartType."): + part_type = part_type.removeprefix("PartType.") + if part_type != "FOLDER": return folder_id metadata_obj_id = getattr(pp, "metadata_obj_id", None) @@ -66,13 +66,22 @@ def register(mcp: FastMCP) -> None: def list_contents( folder_id: Annotated[ UUID | None, - Field(description="Folder PDO id. Omit to list root-level folders in the tenant."), + Field( + description=( + "Folder PDO id (or its path_part_id — both are accepted). " + "Omit to list root-level folders in the tenant." + ), + ), ] = None, ) -> list[PathPartInfo]: - """List the immediate children of a folder. + """List the immediate children of a folder, like ``ls``. Pass no argument to list root-level folders. Returns one entry per - child (folder or document) with its path-part id, name, and type. + child (folder or document) with its ``path_part_id``, ``name``, + ``part_type``, and ``materialized_path``. If a stale / unknown + ``folder_id`` is supplied, the tool falls back to listing the tenant + root rather than failing — agents discovering the corpus get a useful + result instead of a dead-end 404. """ client = get_api_client() folders = FoldersApi(client) @@ -95,22 +104,30 @@ def list_contents( @mcp.tool() def find( - query: Annotated[str, Field(description="Fuzzy substring of the path-part's name.", min_length=1, max_length=255)], + query: Annotated[ + str, + Field( + description="Fuzzy substring of the path-part's name.", min_length=1, max_length=255 + ), + ], parent_path_part_id: Annotated[ UUID | None, - Field(description="Restrict search to descendants of this folder. Omit for whole tenant."), + Field( + description="Restrict search to descendants of this folder. Omit for whole tenant." 
+ ), ] = None, ) -> list[PathPartInfo]: """Fuzzy-search path-parts (folders, documents, sections) by name. - Use when the user refers to a document by a remembered title fragment. + Use when the user refers to a document by a remembered title fragment + ("the onboarding handbook", "Q3 forecast"). For matching the *body* of + a document, use ``search_keyword`` instead — ``find`` only looks at + names. """ client = get_api_client() folders = FoldersApi(client) try: - result = folders.search_items( - name_like=query, parent_path_part_id=parent_path_part_id - ) + result = folders.search_items(name_like=query, parent_path_part_id=parent_path_part_id) except ksapi.ApiException as exc: raise rest_to_mcp(exc) from exc items = getattr(result, "items", None) or result or [] @@ -118,12 +135,16 @@ def find( @mcp.tool() def get_info( - path_part_id: Annotated[UUID, Field(description="Any PDO id — folder, document, section, or chunk.")], + path_part_id: Annotated[ + UUID, Field(description="Any PDO id — folder, document, section, or chunk.") + ], ) -> PathPartAncestry: """Return a path-part's own info plus its root-to-leaf ancestry breadcrumb. Use when you need to resolve a node's type or build a human-readable - path before calling ``read``. + path before calling ``read`` or ``cite``. The returned ``ancestry`` + list is ordered root → … → parent (excluding the node itself), so the + last element is the immediate parent. 
""" client = get_api_client() api = PathPartsApi(client) @@ -136,9 +157,7 @@ def get_info( raise rest_to_mcp(exc) from exc ancestry_items = ( - getattr(ancestry_resp, "ancestors", None) - or getattr(ancestry_resp, "items", None) - or [] + getattr(ancestry_resp, "ancestors", None) or getattr(ancestry_resp, "items", None) or [] ) node_info = _pp_info(node) if node_info is None: diff --git a/src/ks_mcp/tools/cite.py b/src/ks_mcp/tools/cite.py new file mode 100644 index 0000000..9ba2b57 --- /dev/null +++ b/src/ks_mcp/tools/cite.py @@ -0,0 +1,117 @@ +"""Build a structured citation for a chunk_id, ready to drop into an answer.""" + +from typing import Annotated, Any +from uuid import UUID + +import ksapi +from ksapi.api.chunks_api import ChunksApi +from ksapi.api.path_parts_api import PathPartsApi +from ksapi.api.sections_api import SectionsApi +from mcp.server.fastmcp import FastMCP +from pydantic import Field + +from ks_mcp.client import get_api_client +from ks_mcp.errors import rest_to_mcp +from ks_mcp.schema import Citation + +_SNIPPET_CHARS = 240 + + +def _snippet(text: str, limit: int = _SNIPPET_CHARS) -> str: + text = (text or "").strip() + if len(text) <= limit: + return text + # Try to cut at a word boundary so the snippet reads cleanly. 
+ cut = text.rfind(" ", 0, limit) + if cut < int(limit * 0.6): + cut = limit + return text[:cut].rstrip() + "…" + + +def _normalize_part_type(value: Any) -> str: + text = str(value or "") + if text.startswith("PartType."): + text = text.removeprefix("PartType.") + return text + + +def _page_number_from_ancestry(client: ksapi.ApiClient, path_part_id: UUID | None) -> int | None: + """Walk root→leaf ancestry and return the nearest SECTION's page_number, if any.""" + if path_part_id is None: + return None + try: + ancestry: Any = PathPartsApi(client).get_path_part_ancestry(path_part_id=path_part_id) + except ksapi.ApiException: + return None + items = getattr(ancestry, "ancestors", None) or getattr(ancestry, "items", None) or [] + section_metadata_id: UUID | None = None + # Pick the deepest (closest-to-leaf) SECTION ancestor. + for item in items: + inner = getattr(item, "actual_instance", None) or item + if _normalize_part_type(getattr(inner, "part_type", "")) == "SECTION": + candidate = getattr(inner, "metadata_obj_id", None) + if candidate is not None: + section_metadata_id = candidate + if section_metadata_id is None: + return None + try: + section: Any = SectionsApi(client).get_section(section_id=section_metadata_id) + except ksapi.ApiException: + return None + page = getattr(section, "page_number", None) + return int(page) if isinstance(page, int) else None + + +def register(mcp: FastMCP) -> None: + @mcp.tool() + def cite( + chunk_id: Annotated[ + UUID, + Field( + description=( + "Chunk id to build a citation for — typically the ``chunk_id`` from a " + "``search_knowledge`` / ``search_keyword`` hit. NOT a ``path_part_id``." + ), + ), + ], + ) -> Citation: + """Build a structured citation for a single chunk. + + Returns ``document_name``, ``materialized_path``, ``page_number`` (when + the chunk lives under a paginated SECTION), a short ``snippet``, and a + stable ``tag`` string (``[chunk:UUID]``) suitable for inline use in the + agent's prose. 
+ + Recommended usage: call ``cite`` once per chunk that backed an answer, + then append the ``tag`` to the relevant sentence and surface the rest + (document, path, page, snippet) in a footnote / sources block. + """ + client = get_api_client() + try: + # `with_document=True` is required for `chunk.document.name` to be + # populated — otherwise every citation comes back document-less. + chunk: Any = ChunksApi(client).get_chunk( + chunk_id=chunk_id, + with_document=True, + ) + except ksapi.ApiException as exc: + raise rest_to_mcp(exc) from exc + + document = getattr(chunk, "document", None) + document_name = ( + getattr(document, "name", None) + or getattr(chunk, "document_name", None) + or "Untitled document" + ) + materialized_path = getattr(chunk, "materialized_path", None) + body = getattr(chunk, "content", None) or getattr(chunk, "text", "") or "" + page = _page_number_from_ancestry(client, getattr(chunk, "path_part_id", None)) + + return Citation( + chunk_id=chunk_id, + document_name=str(document_name), + materialized_path=materialized_path, + page_number=page, + snippet=_snippet(body), + tag=f"[chunk:{chunk_id}]", + ) diff --git a/src/ks_mcp/tools/org.py b/src/ks_mcp/tools/org.py index cec68ab..b59125e 100644 --- a/src/ks_mcp/tools/org.py +++ b/src/ks_mcp/tools/org.py @@ -1,6 +1,5 @@ """Tenant-context tools: organisation info + current datetime.""" - import os from datetime import UTC, datetime from zoneinfo import ZoneInfo @@ -75,12 +74,22 @@ def _fetch_tenant() -> OrganizationInfo: def register(mcp: FastMCP) -> None: @mcp.tool() def get_organization_info() -> OrganizationInfo: - """Return the caller's tenant metadata: id, name, default language, timezone.""" + """Return the caller's tenant metadata: id, name, default language, timezone. + + Cached per process. Useful as the first call when the user's question + depends on locale ("translate this to our default language") or the + tenant's name ("write an internal email at "). 
+ """ return _fetch_tenant() @mcp.tool() def get_current_datetime() -> CurrentDateTime: - """Return current date/time in both UTC and the tenant's timezone.""" + """Return current date/time in both UTC and the tenant's timezone. + + Use whenever the user's query is relative ("yesterday's notes", "this + quarter's results"). The tool consults ``get_organization_info`` for + the tenant's IANA timezone, falling back to UTC if unset. + """ info = _fetch_tenant() now_utc = datetime.now(UTC) try: diff --git a/src/ks_mcp/tools/provenance.py b/src/ks_mcp/tools/provenance.py index 3177846..363c835 100644 --- a/src/ks_mcp/tools/provenance.py +++ b/src/ks_mcp/tools/provenance.py @@ -5,7 +5,6 @@ changed between two versions — not just retrieve text. """ - import difflib from typing import Annotated, Any from uuid import UUID @@ -49,14 +48,18 @@ def _flatten_version_text(client: ksapi.ApiClient, version_id: UUID, limit: int offset = 0 while True: contents: Any = api.get_document_version_contents( - version_id=version_id, limit=100, offset=offset, + version_id=version_id, + limit=100, + offset=offset, ) items = getattr(contents, "items", None) or [] if not items: break for item in items: inner = getattr(item, "actual_instance", None) or item - text = getattr(inner, "text", None) or getattr(inner, "content", "") or "" + # ksapi's ChunkContentItem stores body on ``.content``; tolerate ``.text`` + # for older SDK builds. 
+ text = getattr(inner, "content", None) or getattr(inner, "text", "") or "" if text: lines.append(text.strip()) if len(items) < 100 or len(lines) >= limit: @@ -68,13 +71,17 @@ def _flatten_version_text(client: ksapi.ApiClient, version_id: UUID, limit: int def register(mcp: FastMCP) -> None: @mcp.tool() def trace_chunk_lineage( - chunk_id: Annotated[UUID, Field(description="The chunk whose lineage you want to inspect.")], + chunk_id: Annotated[ + UUID, Field(description="The chunk whose lineage you want to inspect.") + ], ) -> LineageResult: - """Return the lineage graph for a chunk. + """Return the lineage graph for a chunk (merge/split/re-embed history). - KS tracks how chunks are derived (merge / split / re-embed / re-ingest) so - that agents can explain *why* a piece of evidence exists. Use this when a - downstream answer cites a chunk and you need to justify provenance. + KS tracks how chunks are derived (merge / split / re-embed / + re-ingest) so that agents can explain *why* a piece of evidence + exists. Use this when a downstream answer cites a chunk and you need + to justify provenance — for example to answer "did this fact survive + from v3, or did it appear in v4?". 
""" client = get_api_client() api = ChunkLineagesApi(client) @@ -102,16 +109,23 @@ def trace_chunk_lineage( @mcp.tool() def compare_versions( - document_id: Annotated[UUID, Field(description="Document whose versions you want to diff.")], + document_id: Annotated[ + UUID, Field(description="Document whose versions you want to diff.") + ], from_version_id: Annotated[UUID, Field(description="Older / baseline version id.")], to_version_id: Annotated[UUID, Field(description="Newer / target version id.")], - max_chunks_per_side: Annotated[int, Field(description="Cap per-version chunks loaded for the diff.", ge=10, le=2000)] = 500, + max_chunks_per_side: Annotated[ + int, Field(description="Cap per-version chunks loaded for the diff.", ge=10, le=2000) + ] = 500, ) -> VersionDiffResult: """Produce a unified text diff between two versions of the same document. - Client-side line diff over each version's flattened chunk text — enough - for agents to answer "what changed in v5 vs v4?" without loading both - versions wholesale into the prompt. + Client-side line diff over each version's flattened chunk text — + enough for agents to answer "what changed in v5 vs v4?" without + loading both versions wholesale into the prompt. The diff is capped + to ``max_chunks_per_side`` chunks per version; for very long + documents, scope the question with ``parent_path_part_ids`` on a + prior ``search_*`` call instead. 
""" client = get_api_client() try: @@ -122,7 +136,8 @@ def compare_versions( diff_lines = list( difflib.unified_diff( - before, after, + before, + after, fromfile=f"v:{from_version_id}", tofile=f"v:{to_version_id}", lineterm="", @@ -130,7 +145,9 @@ def compare_versions( ) ) added = sum(1 for line in diff_lines if line.startswith("+") and not line.startswith("+++")) - removed = sum(1 for line in diff_lines if line.startswith("-") and not line.startswith("---")) + removed = sum( + 1 for line in diff_lines if line.startswith("-") and not line.startswith("---") + ) return VersionDiffResult( document_id=document_id, from_version_id=from_version_id, diff --git a/src/ks_mcp/tools/read.py b/src/ks_mcp/tools/read.py index 63ad163..479e6e7 100644 --- a/src/ks_mcp/tools/read.py +++ b/src/ks_mcp/tools/read.py @@ -1,6 +1,5 @@ """Read a PDO (document / section / chunk), read neighbours, view image chunks.""" - import base64 from typing import Annotated, Any from uuid import UUID @@ -16,7 +15,7 @@ from pydantic import Field from ks_mcp.client import get_api_client -from ks_mcp.errors import rest_to_mcp +from ks_mcp.errors import is_not_found, rest_to_mcp def _truncate(text: str, limit: int) -> str: @@ -25,39 +24,90 @@ def _truncate(text: str, limit: int) -> str: return text[:limit] + f"\n...[truncated; {len(text) - limit} more chars]" +def _normalize_part_type(value: Any) -> str: + """Coerce a ksapi PartType (or string) into a bare uppercase token.""" + text = str(value or "") + if text.startswith("PartType."): + text = text.removeprefix("PartType.") + return text + + +def _read_chunk_body(client: ksapi.ApiClient, chunk_id: UUID, max_chars: int) -> str: + chunk: Any = ChunksApi(client).get_chunk(chunk_id=chunk_id) + # ksapi model exposes the text as .content on Chunk; older builds used .text. 
+ body = getattr(chunk, "content", None) or getattr(chunk, "text", "") or "" + return _truncate(f"{body}\n\n[chunk:{chunk_id}]", max_chars) + + def register(mcp: FastMCP) -> None: @mcp.tool() def read( - path_part_id: Annotated[UUID, Field(description="Any PDO id (folder, document, section, or chunk).")], - max_chars: Annotated[int, Field(description="Truncate returned text to this many characters.", ge=100, le=50_000)] = 4000, + path_part_id: Annotated[ + UUID, + Field( + description=( + "Any PDO id (folder, document, section, or chunk path-part) — OR a " + "raw chunk_id from a search hit. The tool first tries to resolve it " + "as a path-part; on 404 it falls back to fetching as a chunk." + ), + ), + ], + max_chars: Annotated[ + int, + Field(description="Truncate returned text to this many characters.", ge=100, le=50_000), + ] = 4000, ) -> str: - """Read the contents of any PDO. Dispatches on part type. + """Read the contents of any PDO and return Markdown text. - For documents/sections the structural outline is returned; for chunks - the raw text. Use after ``find`` or ``search_*`` when you want full text. + Dispatch: + + * **CHUNK** → raw chunk text + a ``[chunk:UUID]`` citation tag. + * **SECTION** → section name + page number; use ``read`` on the parent + DOCUMENT for full text. + * **DOCUMENT** → flattened, ordered chunks with section headings, + paginated up to ``max_chars``. + * **FOLDER / unknown** → name + a hint to use ``list_contents`` to drill in. + + If ``path_part_id`` 404s as a path-part, ``read`` retries the lookup as + a chunk, so callers can pass either a ``path_part_id`` or a ``chunk_id`` + from a search result without a second round-trip. Pair with ``cite`` + when you also need a citation footer. """ client = get_api_client() path_parts = PathPartsApi(client) try: pp = path_parts.get_path_part(path_part_id=path_part_id) except ksapi.ApiException as exc: + # Agents often pass a chunk_id straight from a search hit. 
Treat 404 + # as "maybe it's a chunk id" and retry rather than dead-ending. + if is_not_found(exc): + try: + return _read_chunk_body(client, path_part_id, max_chars) + except ksapi.ApiException as inner: + raise rest_to_mcp(inner) from inner raise rest_to_mcp(exc) from exc - part_type = getattr(pp, "part_type", "") + part_type = _normalize_part_type(getattr(pp, "part_type", "")) metadata_obj_id = getattr(pp, "metadata_obj_id", None) try: if part_type == "CHUNK" and metadata_obj_id: - chunk: Any = ChunksApi(client).get_chunk(chunk_id=metadata_obj_id) - return _truncate(getattr(chunk, "text", "") or "", max_chars) + return _read_chunk_body(client, metadata_obj_id, max_chars) if part_type == "SECTION" and metadata_obj_id: section: Any = SectionsApi(client).get_section(section_id=metadata_obj_id) - body = getattr(section, "text", None) or getattr(section, "title", "") - return _truncate(body or f"(section has no body: {pp.name})", max_chars) + page = getattr(section, "page_number", None) + page_suffix = f" (page {page})" if page is not None else "" + header = f"# {getattr(section, 'name', None) or pp.name}{page_suffix}" + hint = ( + "(SECTION has no inline body — read its parent DOCUMENT for full text, " + "or use list_contents on this path-part to walk children.)" + ) + return _truncate(f"{header}\n\n{hint}", max_chars) if part_type == "DOCUMENT" and metadata_obj_id: from ksapi.api.document_versions_api import DocumentVersionsApi + doc: Any = DocumentsApi(client).get_document(document_id=metadata_obj_id) version_id = getattr(doc, "active_version_id", None) if version_id is None: @@ -66,20 +116,27 @@ def read( offset = 0 while True: contents: Any = DocumentVersionsApi(client).get_document_version_contents( - version_id=version_id, limit=100, offset=offset, + version_id=version_id, + limit=100, + offset=offset, ) items = getattr(contents, "items", None) or [] if not items: break for item in items: inner = getattr(item, "actual_instance", None) or item - ptype = 
str(getattr(inner, "part_type", "")) or type(inner).__name__ + ptype = _normalize_part_type(getattr(inner, "part_type", "")) name = getattr(inner, "name", "") - if "SECTION" in ptype or "Section" in type(inner).__name__: + if ptype == "SECTION": pieces.append(f"\n## {name}\n") else: - text = getattr(inner, "text", None) or getattr(inner, "content", "") or "" - chunk_id = getattr(inner, "id", None) or getattr(inner, "metadata_obj_id", None) + # ChunkContentItem stores body on `.content`; tolerate `.text`. + text = ( + getattr(inner, "content", None) or getattr(inner, "text", "") or "" + ) + chunk_id = getattr(inner, "metadata_obj_id", None) or getattr( + inner, "id", None + ) tag = f" [chunk:{chunk_id}]" if chunk_id else "" if text: pieces.append(f"{text}{tag}\n") @@ -88,30 +145,57 @@ def read( offset += 100 return _truncate("".join(pieces), max_chars) - return f"{pp.name} ({part_type}) — use list_contents to drill in." + return f"{pp.name} ({part_type or 'UNKNOWN'}) — use list_contents to drill in." except ksapi.ApiException as exc: raise rest_to_mcp(exc) from exc @mcp.tool() def read_around( - chunk_id: Annotated[UUID, Field(description="Anchor chunk id.")], - radius: Annotated[int, Field(description="How many chunks before AND after the anchor to include.", ge=0, le=10)] = 2, + chunk_id: Annotated[UUID, Field(description="Anchor chunk id (NOT a path_part_id).")], + radius: Annotated[ + int, + Field( + description="How many chunks before AND after the anchor to include.", ge=0, le=10 + ), + ] = 2, ) -> str: - """Return the ``radius`` chunks before and after an anchor chunk, concatenated. + """Return the ``radius`` chunks before and after an anchor chunk. + + Great for pulling enough local context when a single chunk isn't + enough — e.g. the answer hinges on a sentence that references "the + table above". - Great for pulling enough local context when a single chunk isn't enough. + Output is ordered preceding → anchor → succeeding. 
Each neighbour is + labelled (``[ANCHOR]`` or ``[ctx N]``) and tagged with its + ``[chunk:UUID]`` so the agent can cite the right neighbour, not just + the anchor. """ api = ChunksApi(get_api_client()) try: - neighbours: Any = api.get_chunk_neighbors(chunk_id=chunk_id, radius=radius) + # ksapi exposes `prev` and `next` separately; `radius` is the symmetric + # convenience the MCP tool surfaces to keep the agent UX simple. + neighbours: Any = api.get_chunk_neighbors( + chunk_id=chunk_id, + prev=radius, + next=radius, + chunks_only=True, + ) except ksapi.ApiException as exc: raise rest_to_mcp(exc) from exc + items = getattr(neighbours, "items", []) or [] + anchor_index = getattr(neighbours, "anchor_index", -1) pieces: list[str] = [] - for item in getattr(neighbours, "items", []) or []: - label = "ANCHOR" if getattr(item, "is_anchor", False) else "..." - text = getattr(item, "text", "") or "" - pieces.append(f"[{label}] {text}") + for idx, raw in enumerate(items): + # Items are SectionContentItemOrChunkContentItem — unwrap the union. + # `chunks_only=True` upstream means non-chunk neighbours are already + # filtered out, so we just render whatever comes back. + inner = getattr(raw, "actual_instance", None) or raw + text = getattr(inner, "content", None) or getattr(inner, "text", "") or "" + inner_chunk_id = getattr(inner, "metadata_obj_id", None) or getattr(inner, "id", None) + label = "ANCHOR" if idx == anchor_index else f"ctx {idx - anchor_index:+d}" + tag = f" [chunk:{inner_chunk_id}]" if inner_chunk_id else "" + pieces.append(f"[{label}]{tag}\n{text}") return "\n\n".join(pieces) or "(no neighbours returned)" @mcp.tool() @@ -120,8 +204,11 @@ def view_chunk_image( ) -> ImageContent: """Fetch the image bytes for an IMAGE-type chunk and return them to the agent. - Only works for chunks whose metadata carries at least one S3 URL. - Agent frameworks that support multi-modal content will render inline. 
+ Only works for chunks whose metadata carries at least one S3 URL + (typically ``chunk_type == "IMAGE"``). Agent frameworks that support + multi-modal content (Claude, GPT-4o, Gemini) render the result inline; + text-only frameworks should expect this tool to error and call + ``read``/``cite`` for a textual surrogate instead. """ api = ChunksApi(get_api_client()) try: diff --git a/src/ks_mcp/tools/search.py b/src/ks_mcp/tools/search.py index 1ec822b..ff9f751 100644 --- a/src/ks_mcp/tools/search.py +++ b/src/ks_mcp/tools/search.py @@ -1,11 +1,11 @@ """Semantic + keyword search over the tenant's knowledge base.""" - from typing import Annotated, Any from uuid import UUID import ksapi from ksapi.api.chunks_api import ChunksApi +from ksapi.models.search_type import SearchType from mcp.server.fastmcp import FastMCP from pydantic import Field @@ -19,63 +19,97 @@ def _build_search_request( top_k: int, parent_path_part_ids: list[UUID] | None, tag_ids: list[UUID] | None, - distinct_files: bool, - search_type: str, -) -> Any: + search_type: SearchType, +) -> ksapi.ChunkSearchRequest: + # The backend field is `parent_path_ids` (UUIDs of path-part scopes). + # The MCP-facing param keeps the longer name (`parent_path_part_ids`) for + # clarity since "path_part_id" is the canonical id concept on the API. + # ``with_document=True`` is required to populate ``chunk.document.name`` — + # without it every hit's document name comes back empty. 
return ksapi.ChunkSearchRequest( query=query, top_k=top_k, search_type=search_type, - parent_path_part_ids=[str(p) for p in parent_path_part_ids] - if parent_path_part_ids else None, - tag_ids=[str(t) for t in tag_ids] if tag_ids else None, - distinct_files=distinct_files, + parent_path_ids=list(parent_path_part_ids) if parent_path_part_ids else None, + tag_ids=list(tag_ids) if tag_ids else None, + with_document=True, ) def _hit_from_scored_chunk(scored: Any) -> ChunkHit: + # ksapi's ScoredChunkResponse is flat: chunk fields and `score` live side-by-side + # at the top level. The defensive `scored.chunk` fallback keeps backwards-compat + # with older SDK builds where the chunk was nested. chunk = getattr(scored, "chunk", scored) + document = getattr(chunk, "document", None) + document_name = getattr(document, "name", None) or getattr(chunk, "document_name", None) or "" + # The chunk text lives on `.content` in current ksapi; older builds used `.text`. + body = getattr(chunk, "content", None) or getattr(chunk, "text", "") or "" + raw_chunk_type: Any = getattr(chunk, "chunk_type", ChunkType.TEXT.value) + chunk_type_value = getattr(raw_chunk_type, "value", None) or str(raw_chunk_type) + try: + chunk_type = ChunkType(chunk_type_value) + except ValueError: + chunk_type = ChunkType.UNKNOWN return ChunkHit( chunk_id=chunk.id, - document_name=getattr(chunk, "document_name", None) - or (chunk.document.name if getattr(chunk, "document", None) else "") - or "", - text=(getattr(chunk, "text", "") or ""), + document_name=document_name, + text=body, score=getattr(scored, "score", None), - chunk_type=ChunkType(getattr(chunk, "chunk_type", ChunkType.TEXT.value)), + chunk_type=chunk_type, path_part_id=getattr(chunk, "path_part_id", None), + materialized_path=getattr(chunk, "materialized_path", None), ) def register(mcp: FastMCP) -> None: @mcp.tool() def search_knowledge( - query: Annotated[str, Field(description="Natural-language query for semantic retrieval.", min_length=1, 
max_length=4000)], - top_k: Annotated[int, Field(description="Max number of chunks to return (1-50).", ge=1, le=50)] = 5, + query: Annotated[ + str, + Field( + description="Natural-language query for semantic retrieval.", + min_length=1, + max_length=4000, + ), + ], + top_k: Annotated[ + int, Field(description="Max number of chunks to return (1-50).", ge=1, le=50) + ] = 5, parent_path_part_ids: Annotated[ list[UUID] | None, - Field(description="Restrict search to descendants of these path-parts. Omit for whole tenant."), + Field( + description="Restrict search to descendants of these path-parts. Omit for whole tenant." + ), ] = None, tag_ids: Annotated[ list[UUID] | None, Field(description="Only include chunks whose document carries all these tag UUIDs."), ] = None, - distinct_files: Annotated[ - bool, - Field(description="If true, at most one chunk per source document is returned."), - ] = False, ) -> SearchResult: """Semantic (dense-vector) search over the tenant's chunks. - Use for conceptual questions: returns passages semantically related - to the query with a relevance score. For exact-term lookups prefer - ``search_keyword`` instead. + Use for conceptual questions ("how does X work", "anything about Y"). + Returns passages semantically related to the query with a relevance + score (higher = better). For exact-term lookups, prefer + ``search_keyword``. + + Each hit carries a ``chunk_id``, ``materialized_path`` and a ``text`` + snippet. Recommended follow-ups: + + * ``cite(chunk_id)`` — get a structured citation for the answer. + * ``read_around(chunk_id, radius=2)`` — pull surrounding context. + * ``read(path_part_id=chunk_id)`` — fetch the full chunk body. 
""" api = ChunksApi(get_api_client()) try: response = api.search_chunks( chunk_search_request=_build_search_request( - query, top_k, parent_path_part_ids, tag_ids, distinct_files, "dense_only" + query, + top_k, + parent_path_part_ids, + tag_ids, + SearchType.DENSE_ONLY, ) ) except ksapi.ApiException as exc: @@ -85,21 +119,43 @@ def search_knowledge( @mcp.tool() def search_keyword( - query: Annotated[str, Field(description="Keyword or phrase to match (BM25 full-text).", min_length=1, max_length=4000)], - top_k: Annotated[int, Field(description="Max number of chunks to return (1-50).", ge=1, le=50)] = 5, - parent_path_part_ids: Annotated[list[UUID] | None, Field(description="Restrict to descendants of these path-parts.")] = None, - tag_ids: Annotated[list[UUID] | None, Field(description="Only chunks whose document carries all these tags.")] = None, - distinct_files: Annotated[bool, Field(description="One chunk per source document when true.")] = False, + query: Annotated[ + str, + Field( + description="Keyword or phrase to match (BM25 full-text).", + min_length=1, + max_length=4000, + ), + ], + top_k: Annotated[ + int, Field(description="Max number of chunks to return (1-50).", ge=1, le=50) + ] = 5, + parent_path_part_ids: Annotated[ + list[UUID] | None, Field(description="Restrict to descendants of these path-parts.") + ] = None, + tag_ids: Annotated[ + list[UUID] | None, + Field(description="Only chunks whose document carries all these tags."), + ] = None, ) -> SearchResult: """BM25 / keyword search over the tenant's chunks. - Use when the user mentions a specific term, name, or quoted phrase. + Use when the user mentions a specific term, name, identifier, or + quoted phrase that needs an exact (or near-exact) match. For + conceptual queries, prefer ``search_knowledge``. + + Each hit carries a ``chunk_id`` and ``materialized_path``. Pair with + ``cite``/``read``/``read_around`` to ground an answer. 
""" api = ChunksApi(get_api_client()) try: response = api.search_chunks( chunk_search_request=_build_search_request( - query, top_k, parent_path_part_ids, tag_ids, distinct_files, "full_text" + query, + top_k, + parent_path_part_ids, + tag_ids, + SearchType.FULL_TEXT, ) ) except ksapi.ApiException as exc: diff --git a/tests/conftest.py b/tests/conftest.py index b219097..02f3daf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ - import pytest diff --git a/tests/test_ask.py b/tests/test_ask.py new file mode 100644 index 0000000..578287b --- /dev/null +++ b/tests/test_ask.py @@ -0,0 +1,105 @@ +"""Unit tests for the ``ask`` tool's SSE parser and result builder.""" + +from uuid import UUID + +from ks_mcp.schema import AskResult +from ks_mcp.tools.ask import ( + _build_result, + _parse_sse_block, + _to_ask_citation, +) + +_THREAD_ID = UUID("00000000-0000-0000-0000-000000000010") +_MSG_ID = UUID("00000000-0000-0000-0000-000000000020") +_CHUNK_ID = UUID("00000000-0000-0000-0000-000000000030") + + +def test_parse_sse_block_event_and_data() -> None: + event, data = _parse_sse_block('event: text_delta\ndata: {"delta":"hi"}') + assert event == "text_delta" + assert data == '{"delta":"hi"}' + + +def test_parse_sse_block_terminal_done() -> None: + event, data = _parse_sse_block("data: [DONE]") + assert event is None + assert data == "[DONE]" + + +def test_parse_sse_block_multiline_data_is_joined() -> None: + event, data = _parse_sse_block("event: x\ndata: line1\ndata: line2") + assert event == "x" + assert data == "line1\nline2" + + +def test_parse_sse_block_ignores_comments_and_garbage() -> None: + event, data = _parse_sse_block(": keepalive\n\n") + assert event is None + assert data == "" + + # ``id:`` and unknown fields silently dropped. 
+ event, data = _parse_sse_block("id: 42\nevent: text_delta\ndata: payload") + assert event == "text_delta" + assert data == "payload" + + +def test_parse_sse_block_handles_field_without_colon() -> None: + # SSE technically allows "field" without colon — we drop it rather than crash. + event, data = _parse_sse_block("event: x\ndata: ok\nweird-line") + assert event == "x" + assert data == "ok" + + +def test_to_ask_citation_full_payload() -> None: + raw = { + "chunk_id": str(_CHUNK_ID), + "quote": "Q", + "document_id": "00000000-0000-0000-0000-000000000040", + "document_name": "Doc", + "materialized_path": "/a/b/c", + "page_number": 7, + } + citation = _to_ask_citation(raw) + assert citation.chunk_id == _CHUNK_ID + assert citation.quote == "Q" + assert str(citation.document_id) == "00000000-0000-0000-0000-000000000040" + assert citation.document_name == "Doc" + assert citation.materialized_path == "/a/b/c" + assert citation.page_number == 7 + + +def test_to_ask_citation_minimal_payload() -> None: + citation = _to_ask_citation({"chunk_id": str(_CHUNK_ID), "quote": ""}) + assert citation.chunk_id == _CHUNK_ID + assert citation.quote == "" + assert citation.document_id is None + assert citation.document_name is None + assert citation.materialized_path is None + assert citation.page_number is None + + +def test_build_result_concatenates_text_parts() -> None: + result: AskResult = _build_result(["Hello, ", "world!"], [], _THREAD_ID, _MSG_ID, False, "") + assert result.answer == "Hello, world!" 
+ assert result.is_error is False + assert result.thread_id == _THREAD_ID + assert result.message_id == _MSG_ID + + +def test_build_result_surfaces_error_text_when_no_partial_answer() -> None: + result = _build_result([], [], _THREAD_ID, None, True, "rate limited") + assert result.is_error is True + assert "rate limited" in result.answer + + +def test_build_result_keeps_partial_answer_on_error() -> None: + result = _build_result(["partial"], [], _THREAD_ID, None, True, "blew up") + # We keep whatever the agent already streamed rather than overwriting with the error. + assert result.answer == "partial" + assert result.is_error is True + + +def test_build_result_empty_stream_returns_placeholder() -> None: + result = _build_result([], [], _THREAD_ID, None, False, "") + assert result.answer == "(empty answer)" + assert result.is_error is False diff --git a/tests/test_cite.py b/tests/test_cite.py new file mode 100644 index 0000000..e124b0a --- /dev/null +++ b/tests/test_cite.py @@ -0,0 +1,140 @@ +"""Unit tests for ``cite`` tool helpers and the page-number ancestry walk.""" + +from types import SimpleNamespace +from typing import Any, cast +from uuid import UUID + +import ksapi + +from ks_mcp.tools.cite import _normalize_part_type, _page_number_from_ancestry, _snippet + +_PATH_PART_ID = UUID("00000000-0000-0000-0000-000000000001") +_SECTION_METADATA_ID = UUID("00000000-0000-0000-0000-000000000002") + + +def test_snippet_returns_short_text_unchanged() -> None: + assert _snippet("hello") == "hello" + + +def test_snippet_truncates_long_text_at_word_boundary() -> None: + text = "alpha beta gamma " * 50 + out = _snippet(text, limit=40) + assert out.endswith("…") + # Truncation must respect the cap — allow some slack for the ellipsis. + assert len(out) <= 41 + # Word-boundary cut: should not slice mid-word. 
+ assert " " in out + assert not out.removesuffix("…").endswith("alph") + + +def test_snippet_falls_back_to_hard_cut_when_no_late_space() -> None: + text = "x" * 300 + out = _snippet(text, limit=50) + assert out == "x" * 50 + "…" + + +def test_snippet_handles_empty_input() -> None: + assert _snippet("") == "" + assert _snippet(" ") == "" + + +def test_normalize_part_type_strips_enum_prefix() -> None: + assert _normalize_part_type("PartType.SECTION") == "SECTION" + assert _normalize_part_type("SECTION") == "SECTION" + assert _normalize_part_type("") == "" + assert _normalize_part_type(None) == "" + + +class _FakeAncestor: + """Minimal stand-in for an ancestor item the SDK might return.""" + + def __init__(self, part_type: str, metadata_obj_id: UUID | None = None) -> None: + self.part_type = part_type + self.metadata_obj_id = metadata_obj_id + + +class _FakeAncestryResp: + def __init__(self, ancestors: list[_FakeAncestor]) -> None: + self.ancestors = ancestors + + +class _FakeSection: + def __init__(self, page_number: int | None) -> None: + self.page_number = page_number + + +class _StubPathPartsApi: + def __init__(self, ancestry: Any | Exception) -> None: + self._ancestry = ancestry + + def get_path_part_ancestry(self, path_part_id: UUID) -> Any: + if isinstance(self._ancestry, Exception): + raise self._ancestry + return self._ancestry + + +class _StubSectionsApi: + def __init__(self, section: Any | Exception) -> None: + self._section = section + + def get_section(self, section_id: UUID) -> Any: + if isinstance(self._section, Exception): + raise self._section + return self._section + + +def test_page_number_from_ancestry_returns_none_when_no_path_part() -> None: + client = cast(ksapi.ApiClient, _FakeClient(None, None)) + assert _page_number_from_ancestry(client, None) is None + + +def test_page_number_from_ancestry_picks_deepest_section(monkeypatch: Any) -> None: + ancestors = [ + _FakeAncestor("FOLDER"), + _FakeAncestor("DOCUMENT"), + _FakeAncestor("SECTION", 
UUID("00000000-0000-0000-0000-0000000000aa")), + _FakeAncestor("SECTION", _SECTION_METADATA_ID), + ] + section = _FakeSection(page_number=42) + client = cast(ksapi.ApiClient, _FakeClient(_FakeAncestryResp(ancestors), section)) + + import ks_mcp.tools.cite as cite_module + + monkeypatch.setattr(cite_module, "PathPartsApi", lambda c: _StubPathPartsApi(c.ancestry)) + monkeypatch.setattr(cite_module, "SectionsApi", lambda c: _StubSectionsApi(c.section)) + + page = _page_number_from_ancestry(client, _PATH_PART_ID) + assert page == 42 + + +def test_page_number_from_ancestry_returns_none_when_no_section(monkeypatch: Any) -> None: + ancestors = [_FakeAncestor("FOLDER"), _FakeAncestor("DOCUMENT")] + client = cast(ksapi.ApiClient, _FakeClient(_FakeAncestryResp(ancestors), None)) + + import ks_mcp.tools.cite as cite_module + + monkeypatch.setattr(cite_module, "PathPartsApi", lambda c: _StubPathPartsApi(c.ancestry)) + monkeypatch.setattr(cite_module, "SectionsApi", lambda c: _StubSectionsApi(c.section)) + + assert _page_number_from_ancestry(client, _PATH_PART_ID) is None + + +def test_page_number_from_ancestry_swallows_api_errors(monkeypatch: Any) -> None: + err = ksapi.ApiException(status=500, reason="boom") + client = cast(ksapi.ApiClient, _FakeClient(err, None)) + + import ks_mcp.tools.cite as cite_module + + monkeypatch.setattr(cite_module, "PathPartsApi", lambda c: _StubPathPartsApi(c.ancestry)) + monkeypatch.setattr(cite_module, "SectionsApi", lambda c: _StubSectionsApi(c.section)) + + # Errors fetching ancestry / section should degrade to None, not propagate — + # the caller treats page_number as best-effort metadata. 
+ assert _page_number_from_ancestry(client, _PATH_PART_ID) is None + + +class _FakeClient(SimpleNamespace): + """Carries the ancestry / section stubs into the patched API constructors.""" + + def __init__(self, ancestry: Any, section: Any) -> None: + super().__init__(ancestry=ancestry, section=section) diff --git a/tests/test_client.py b/tests/test_client.py index b4e5018..28e5a53 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,3 @@ - import pytest diff --git a/tests/test_errors.py b/tests/test_errors.py index 52c10ba..c70f03b 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,4 +1,3 @@ - import ksapi from mcp.types import INTERNAL_ERROR, INVALID_PARAMS diff --git a/tests/test_read.py b/tests/test_read.py new file mode 100644 index 0000000..e187cc3 --- /dev/null +++ b/tests/test_read.py @@ -0,0 +1,30 @@ +"""Unit tests for read.py helpers: text truncation, part-type normalization.""" + +from ks_mcp.tools.read import _normalize_part_type, _truncate + + +def test_truncate_under_limit_returns_input() -> None: + assert _truncate("short", 100) == "short" + + +def test_truncate_over_limit_marks_remaining_chars() -> None: + text = "x" * 150 + out = _truncate(text, 100) + assert out.startswith("x" * 100) + assert "truncated; 50 more chars" in out + + +def test_truncate_at_exact_limit_is_unchanged() -> None: + text = "x" * 50 + assert _truncate(text, 50) == text + + +def test_normalize_part_type_strips_enum_prefix() -> None: + assert _normalize_part_type("PartType.CHUNK") == "CHUNK" + assert _normalize_part_type("PartType.SECTION") == "SECTION" + assert _normalize_part_type("FOLDER") == "FOLDER" + + +def test_normalize_part_type_handles_empty_input() -> None: + assert _normalize_part_type("") == "" + assert _normalize_part_type(None) == "" diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..8312ec8 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,95 @@ +"""Schema round-trip tests for the 
public output models.""" + +from uuid import UUID + +import pytest +from pydantic import ValidationError + +from ks_mcp.schema import AskCitation, AskResult, ChunkHit, ChunkType, Citation + +_CHUNK_ID = UUID("00000000-0000-0000-0000-000000000001") +_PATH_PART_ID = UUID("00000000-0000-0000-0000-000000000002") +_THREAD_ID = UUID("00000000-0000-0000-0000-000000000003") +_DOC_ID = UUID("00000000-0000-0000-0000-000000000004") + + +def test_chunk_hit_carries_materialized_path() -> None: + hit = ChunkHit( + chunk_id=_CHUNK_ID, + document_name="Doc", + text="body", + score=0.5, + chunk_type=ChunkType.TEXT, + path_part_id=_PATH_PART_ID, + materialized_path="A/B/C", + ) + dumped = hit.model_dump() + assert dumped["materialized_path"] == "A/B/C" + assert dumped["chunk_type"] == ChunkType.TEXT + + +def test_chunk_hit_defaults() -> None: + hit = ChunkHit(chunk_id=_CHUNK_ID, document_name="", text="") + assert hit.score is None + assert hit.path_part_id is None + assert hit.materialized_path is None + assert hit.chunk_type == ChunkType.TEXT + + +def test_citation_serializes_to_expected_keys() -> None: + cit = Citation( + chunk_id=_CHUNK_ID, + document_name="Handbook", + materialized_path="HR/Onboarding/Section 2", + page_number=4, + snippet="hello", + tag=f"[chunk:{_CHUNK_ID}]", + ) + dumped = cit.model_dump() + assert set(dumped) == { + "chunk_id", + "document_name", + "materialized_path", + "page_number", + "snippet", + "tag", + } + assert dumped["tag"] == f"[chunk:{_CHUNK_ID}]" + + +def test_ask_result_default_citations_is_empty_list() -> None: + result = AskResult(answer="ok", thread_id=_THREAD_ID) + assert result.citations == [] + assert result.is_error is False + assert result.message_id is None + assert result.workflow_id is None + + +def test_ask_result_with_citations_round_trip() -> None: + citation = AskCitation( + chunk_id=_CHUNK_ID, + quote="the quoted span", + document_id=_DOC_ID, + document_name="Doc", + materialized_path="A/B", + page_number=2, + ) + result = 
AskResult( + answer="grounded answer", + citations=[citation], + thread_id=_THREAD_ID, + message_id=UUID("00000000-0000-0000-0000-000000000099"), + workflow_id="wf-123", + ) + dumped = result.model_dump() + assert dumped["answer"] == "grounded answer" + assert dumped["workflow_id"] == "wf-123" + assert dumped["citations"][0]["quote"] == "the quoted span" + assert dumped["citations"][0]["document_id"] == _DOC_ID + + +def test_ask_result_requires_thread_id() -> None: + with pytest.raises(ValidationError): + # Intentionally validated without thread_id to confirm the model rejects it; + # a dict via model_validate avoids pyright's missing-kwarg error on AskResult(...). + AskResult.model_validate({"answer": "x"}) diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 0000000..2c7a2f4 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,108 @@ +"""Unit tests for the search tool: request shape + scored hit projection.""" + +from types import SimpleNamespace +from uuid import UUID + +import pytest +from ksapi.models.search_type import SearchType + +from ks_mcp.schema import ChunkType +from ks_mcp.tools.search import _build_search_request, _hit_from_scored_chunk + +_PARENT_ID = UUID("00000000-0000-0000-0000-0000000000aa") +_TAG_ID = UUID("00000000-0000-0000-0000-0000000000bb") +_CHUNK_ID = UUID("00000000-0000-0000-0000-0000000000cc") +_PATH_PART_ID = UUID("00000000-0000-0000-0000-0000000000dd") + + +def test_build_search_request_minimal() -> None: + req = _build_search_request("hello", 10, None, None, SearchType.DENSE_ONLY) + body = req.to_dict() + # Renamed field reaches the wire (was `parent_path_part_ids`). + assert "parent_path_part_ids" not in body + assert body.get("parent_path_ids") is None + # `with_document=True` is required for hits to carry document_name. + assert body.get("with_document") is True + assert body["query"] == "hello" + assert body["top_k"] == 10 + # `distinct_files` is gone — the backend dropped it.
+ assert "distinct_files" not in body + + +def test_build_search_request_with_filters() -> None: + req = _build_search_request( + "term", + 5, + [_PARENT_ID], + [_TAG_ID], + SearchType.FULL_TEXT, + ) + body = req.to_dict() + assert body["parent_path_ids"] == [_PARENT_ID] + assert body["tag_ids"] == [_TAG_ID] + assert body["search_type"] == SearchType.FULL_TEXT + + +def test_hit_from_scored_chunk_full_payload() -> None: + document = SimpleNamespace(name="The Onboarding Handbook") + scored = SimpleNamespace( + id=_CHUNK_ID, + path_part_id=_PATH_PART_ID, + content="The onboarding flow has three steps.", + chunk_type=ChunkType.TEXT.value, + score=0.87, + materialized_path="HR/Handbooks/Onboarding/...", + document=document, + ) + hit = _hit_from_scored_chunk(scored) + assert hit.chunk_id == _CHUNK_ID + assert hit.document_name == "The Onboarding Handbook" + assert hit.text == "The onboarding flow has three steps." + assert hit.score == pytest.approx(0.87) + assert hit.chunk_type == ChunkType.TEXT + assert hit.path_part_id == _PATH_PART_ID + assert hit.materialized_path == "HR/Handbooks/Onboarding/..." + + +def test_hit_from_scored_chunk_falls_back_to_text_field() -> None: + # Older SDK builds exposed body on `.text`; we tolerate either. 
+ scored = SimpleNamespace( + id=_CHUNK_ID, + path_part_id=None, + text="legacy body", + chunk_type="TEXT", + score=None, + document=None, + ) + hit = _hit_from_scored_chunk(scored) + assert hit.text == "legacy body" + assert hit.document_name == "" + assert hit.path_part_id is None + assert hit.materialized_path is None + + +def test_hit_from_scored_chunk_unknown_chunk_type_degrades_to_unknown() -> None: + scored = SimpleNamespace( + id=_CHUNK_ID, + path_part_id=None, + content="x", + chunk_type="WAT", + score=0.1, + document=None, + ) + hit = _hit_from_scored_chunk(scored) + assert hit.chunk_type == ChunkType.UNKNOWN + + +def test_hit_from_scored_chunk_handles_enum_chunk_type() -> None: + # ksapi sometimes hands back an Enum, sometimes a bare string. Both must work. + scored = SimpleNamespace( + id=_CHUNK_ID, + path_part_id=None, + content="x", + chunk_type=SimpleNamespace(value="TABLE"), + score=0.1, + document=None, + ) + hit = _hit_from_scored_chunk(scored) + assert hit.chunk_type == ChunkType.TABLE diff --git a/tests/test_server.py b/tests/test_server.py index e4d6b07..565ecf0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,17 +1,21 @@ """Server-level smoke tests for the KS MCP server.""" - EXPECTED_TOOLS = { "search_knowledge", "search_keyword", "read", "read_around", + "cite", + "ask", "list_contents", "find", "get_info", "view_chunk_image", "get_organization_info", "get_current_datetime", + # Phase 2 provenance: + "trace_chunk_lineage", + "compare_versions", } @@ -20,8 +24,11 @@ def test_build_server_metadata() -> None: mcp = build_server(host="0.0.0.0", port=9999) assert mcp.name == "knowledgestack" - assert "search_knowledge" in mcp.instructions - assert "read_around" in mcp.instructions + instructions = mcp.instructions or "" + assert "search_knowledge" in instructions + assert "read_around" in instructions + assert "cite" in instructions + assert "ask" in instructions async def test_all_tools_registered() -> None: @@ -30,9 +37,7 @@ 
async def test_all_tools_registered() -> None: mcp = build_server() tools = await mcp.list_tools() names = {t.name for t in tools} - assert EXPECTED_TOOLS.issubset(names), ( - f"Missing tools: {EXPECTED_TOOLS - names}" - ) + assert EXPECTED_TOOLS.issubset(names), f"Missing tools: {EXPECTED_TOOLS - names}" async def test_every_tool_has_non_empty_description() -> None: