From f92c4a0a8daef84883df0b06c192d22568afde12 Mon Sep 17 00:00:00 2001 From: wuwangzhang1216 Date: Fri, 15 May 2026 00:01:51 -0400 Subject: [PATCH] Add lightweight code context indexing --- mcp_server/client.py | 67 ++++++ mcp_server/models.py | 14 ++ mcp_server/server.py | 28 ++- opendb/cli.py | 47 ++++ opendb_core/main.py | 2 + opendb_core/routers/context.py | 26 +++ opendb_core/services/context_service.py | 59 +++++ opendb_core/storage/base.py | 10 + opendb_core/storage/postgres.py | 237 +++++++++++++++++++- opendb_core/storage/sqlite.py | 258 ++++++++++++++++++++- opendb_core/utils/code_intel.py | 283 ++++++++++++++++++++++++ opendb_core/workspace.py | 14 ++ sql/schema.sql | 24 ++ tests/test_code_context.py | 214 ++++++++++++++++++ 14 files changed, 1279 insertions(+), 4 deletions(-) create mode 100644 opendb_core/routers/context.py create mode 100644 opendb_core/services/context_service.py create mode 100644 opendb_core/utils/code_intel.py create mode 100644 tests/test_code_context.py diff --git a/mcp_server/client.py b/mcp_server/client.py index 6606656..6b2e666 100644 --- a/mcp_server/client.py +++ b/mcp_server/client.py @@ -197,6 +197,73 @@ async def search( return "\n".join(lines_out) +async def context( + query: str, + limit: int = 8, + include_snippets: bool = True, +) -> str: + """Call POST /context and format a compact agent context bundle.""" + client = await get_client() + response = await client.post( + "/context", + json={ + "query": query, + "limit": limit, + "include_snippets": include_snippets, + }, + ) + + if response.status_code != 200: + return _handle_error(response) + + data = response.json() + symbols = data.get("symbols", []) + snippets = data.get("snippets", []) + related_documents = data.get("related_documents", []) + + lines_out = [f"Context for '{query}'"] + + if symbols: + lines_out.append("") + lines_out.append("Symbols:") + for s in symbols: + location = s.get("source_path") or s.get("filename") + lines_out.append( + f" {s.get('qualified_name')} [{s.get('kind')}] " + f"{location}:{s.get('start_line')}-{s.get('end_line')}" + ) + if s.get("signature"): + lines_out.append(f" {s['signature']}") + + if snippets: + lines_out.append("") + lines_out.append("Snippets:") + for snip in snippets: + location = snip.get("source_path") or snip.get("filename") + lines_out.append( + f"--- {location}:{snip.get('start_line')}-{snip.get('end_line')} " + f"({snip.get('symbol')})" + ) + lines_out.append(snip.get("text", "")) + + if related_documents: + lines_out.append("") + lines_out.append("Related documents:") + for doc in related_documents: + lines_out.append( + f" {doc.get('filename')} page {doc.get('page_number')} " + f"score {doc.get('relevance_score')}" + ) + if doc.get("highlight"): + lines_out.append(f" {doc['highlight']}") + + if not symbols and not related_documents: + lines_out.append("") + lines_out.append("No context found.") + + return "\n".join(lines_out) + + async def get_info() -> str: """Call GET /info and format as readable text.""" client = await get_client() diff --git a/mcp_server/models.py b/mcp_server/models.py index 2f97fa9..81aa8e6 100644 --- a/mcp_server/models.py +++ b/mcp_server/models.py @@ -58,6 +58,20 @@ class SearchInput(BaseModel): offset: int = Field(0, description="Pagination offset", ge=0) +class ContextInput(BaseModel): + """Input for compact agent-oriented context lookup.""" + + model_config = ConfigDict(str_strip_whitespace=True, extra="forbid") + + query: str = Field( + ..., description="Task, symbol, or topic to build compact context for", min_length=1 + ) + limit: int = Field(8, description="Max symbols/results", ge=1, le=20) + include_snippets: bool = Field( + True, description="Include small source snippets around matching symbols" + ) + + class GlobInput(BaseModel): """Input for finding files matching a glob pattern.""" diff --git a/mcp_server/server.py b/mcp_server/server.py index 632977f..e4c9bdb 100644 --- a/mcp_server/server.py +++ b/mcp_server/server.py @@ -1,4 +1,4 @@ -"""OpenDB MCP Server — 3 tools: read, search, glob.""" +"""OpenDB MCP Server.""" from __future__ import annotations @@ -10,7 +10,7 @@ from mcp_server.client import close_client from mcp_server.models import ( - AddWorkspaceInput, CurrentWorkspaceInput, GlobInput, InfoInput, + AddWorkspaceInput, ContextInput, CurrentWorkspaceInput, GlobInput, InfoInput, ListWorkspacesInput, MemoryForgetInput, MemoryRecallInput, MemoryStoreInput, ReadInput, RemoveWorkspaceInput, SearchInput, UseWorkspaceInput, ) @@ -181,6 +181,30 @@ async def opendb_search(params: SearchInput) -> str: ) +@mcp.tool( + name="opendb_context", + annotations={ + "title": "Build Context", + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": False, + }, +) +async def opendb_context(params: ContextInput) -> str: + """Build compact context from indexed code symbols plus full-text hits. + + Use this after opendb_search finds a concept, or directly when you know a + symbol/function/class name. It returns locations and small snippets instead + of dumping entire files into the agent context. + """ + return await opendb.context( + query=params.query, + limit=params.limit, + include_snippets=params.include_snippets, + ) + + @mcp.tool( name="opendb_glob", annotations={ diff --git a/opendb/cli.py b/opendb/cli.py index 4c02303..15439cf 100644 --- a/opendb/cli.py +++ b/opendb/cli.py @@ -5,6 +5,7 @@ opendb init [PATH] # create .opendb/ in PATH (default: current dir) opendb index [PATH] # index PATH (default: current dir) opendb search QUERY # search indexed files + opendb context QUERY # build compact code/document context opendb read FILENAME # read a file opendb memory profile # render a white-box memory profile opendb serve-mcp # start MCP server (stdio, embedded mode) @@ -136,6 +137,52 @@ async def _search() -> dict: typer.echo("") +# --------------------------------------------------------------------------- +# context +# --------------------------------------------------------------------------- + +@app.command() +def context( + query: str = typer.Argument(..., help="Task, symbol, or topic"), + workspace: Path = typer.Option(Path("."), "--workspace", "-w", help="Workspace root"), + limit: int = typer.Option(8, help="Maximum symbols/results"), + json_output: bool = typer.Option(False, "--json", help="Output as JSON"), +) -> None: + """Build compact context from indexed code symbols and documents.""" + from opendb_core.workspace import Workspace + + ws = Workspace.open(workspace) + + async def _context() -> dict: + await ws.init() + result = await ws.context(query, limit=limit) + await ws.close() + return result + + result = _run(_context()) + + if json_output: + typer.echo(json.dumps(result, indent=2, ensure_ascii=False)) + return + + typer.echo(f"Context for '{query}':\n") + for s in result.get("symbols", []): + location = s.get("source_path") or s.get("filename") + typer.echo( + f" {s.get('qualified_name')} [{s.get('kind')}] " + f"{location}:{s.get('start_line')}-{s.get('end_line')}" + ) + if result.get("snippets"): + typer.echo("") + for snip in result["snippets"]: + location = snip.get("source_path") or snip.get("filename") + typer.echo( + f"--- {location}:{snip.get('start_line')}-{snip.get('end_line')}" + ) + typer.echo(snip.get("text", "")) + typer.echo("") + + # --------------------------------------------------------------------------- # read # --------------------------------------------------------------------------- diff --git a/opendb_core/main.py b/opendb_core/main.py index 6694628..e54def1 100644 --- a/opendb_core/main.py +++ b/opendb_core/main.py @@ -9,6 +9,7 @@ from opendb_core.config import settings from opendb_core.services.watch_service import stop_all as stop_all_watchers from opendb_core.routers.files import router as files_router +from opendb_core.routers.context import router as context_router from opendb_core.routers.glob import router as glob_router from opendb_core.routers.health import router as health_router from opendb_core.routers.info import router as info_router @@ -89,6 +90,7 @@ async def value_error_handler(request: Request, exc: ValueError) -> JSONResponse app.include_router(health_router) app.include_router(info_router) app.include_router(files_router) +app.include_router(context_router) app.include_router(glob_router) app.include_router(index_router) app.include_router(read_router) diff --git a/opendb_core/routers/context.py b/opendb_core/routers/context.py new file mode 100644 index 0000000..149c6f4 --- /dev/null +++ b/opendb_core/routers/context.py @@ -0,0 +1,26 @@ +"""Context endpoint: compact agent-oriented search bundle.""" + +from __future__ import annotations + +from fastapi import APIRouter +from pydantic import BaseModel, Field + +from opendb_core.services.context_service import build_context + +router = APIRouter(tags=["context"]) + + +class ContextRequest(BaseModel): + query: str = Field(..., min_length=1) + limit: int = Field(8, ge=1, le=20) + include_snippets: bool = True + + +@router.post("/context") +async def context(request: ContextRequest) -> dict: + """Build compact code/document context for an agent task or symbol lookup.""" + return await build_context( + query=request.query, + limit=request.limit, + include_snippets=request.include_snippets, + ) diff --git a/opendb_core/services/context_service.py b/opendb_core/services/context_service.py new file mode 100644 index 0000000..0bdc222 --- /dev/null +++ b/opendb_core/services/context_service.py @@ -0,0 +1,59 @@ +"""Compact context builder for agent code/document exploration.""" + +from __future__ import annotations + +from opendb_core.storage import get_backend +from opendb_core.utils.text import extract_lines + + +async def build_context(query: str, limit: int = 8, include_snippets: bool = True) -> dict: + """Build a small context bundle from symbols plus full-text hits.""" + backend = get_backend() + symbol_limit = max(1, min(limit, 20)) + symbols = await backend.search_code_symbols(query, limit=symbol_limit) + + snippets: list[dict] = [] + if include_snippets: + seen: set[tuple[str, int, int]] = set() + for symbol in symbols[: min(symbol_limit, 8)]: + file_id = str(symbol["file_id"]) + start = max(1, int(symbol["start_line"]) - 1) + end = int(symbol["end_line"]) + 1 + key = (file_id, start, end) + if key in seen: + continue + seen.add(key) + text_row = await backend.get_file_text(file_id) + snippets.append({ + "file_id": file_id, + "filename": symbol["filename"], + "source_path": symbol.get("source_path"), + "symbol": symbol["qualified_name"], + "start_line": start, + "end_line": end, + "text": extract_lines( + text_row["full_text"], + text_row["line_index"], + start, + min(end, int(text_row["total_lines"])), + ), + }) + + doc_hits = await backend.search_fts(query, {}, max(1, min(limit, 10)), 0) + symbol_file_ids = {str(s["file_id"]) for s in symbols} + related_documents = [ + hit for hit in doc_hits.get("results", []) + if str(hit.get("file_id")) not in symbol_file_ids + ][: max(0, limit - len(symbols[:limit]))] + + return { + "query": query, + "symbols": symbols[:limit], + "snippets": snippets, + "related_documents": related_documents, + "stats": { + "symbol_count": len(symbols), + "snippet_count": len(snippets), + "related_document_count": len(related_documents), + }, + } diff --git a/opendb_core/storage/base.py b/opendb_core/storage/base.py index 941311d..f1b255f 100644 --- a/opendb_core/storage/base.py +++ b/opendb_core/storage/base.py @@ -137,6 +137,16 @@ async def search_fts( """ ... + async def search_code_symbols( + self, + query: str, + *, + limit: int = 20, + kinds: list[str] | None = None, + ) -> list[dict]: + """Search indexed code symbols by name/signature/docstring.""" + ... + # ------------------------------------------------------------------ # Indexing # ------------------------------------------------------------------ diff --git a/opendb_core/storage/postgres.py b/opendb_core/storage/postgres.py index 311baf2..accc345 100644 --- a/opendb_core/storage/postgres.py +++ b/opendb_core/storage/postgres.py @@ -9,6 +9,7 @@ import json import logging +from types import SimpleNamespace logger = logging.getLogger(__name__) @@ -30,6 +31,7 @@ async def init(self) -> None: """Run lightweight schema migrations (add columns if missing).""" await self._migrate_cjk_columns() await self._migrate_eval_and_links() + await self._backfill_code_symbols_if_needed() async def close(self) -> None: pass # Pool lifecycle owned by app/database.py @@ -150,6 +152,122 @@ async def _migrate_eval_and_links(self) -> None: await conn.execute( "CREATE INDEX IF NOT EXISTS idx_file_links_target ON file_links(target)" ) + await conn.execute( + """ + CREATE TABLE IF NOT EXISTS code_symbols ( + id BIGSERIAL PRIMARY KEY, + file_id UUID NOT NULL REFERENCES files(id) ON DELETE CASCADE, + name TEXT NOT NULL, + kind TEXT NOT NULL, + qualified_name TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + signature TEXT NOT NULL DEFAULT '', + docstring TEXT NOT NULL DEFAULT '', + tsv TSVECTOR GENERATED ALWAYS AS ( + to_tsvector('english', name || ' ' || qualified_name || ' ' || signature || ' ' || docstring) + ) STORED, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() + ) + """ + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_code_symbols_file ON code_symbols(file_id)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_code_symbols_name ON code_symbols(name)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_code_symbols_kind ON code_symbols(kind)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_code_symbols_tsv ON code_symbols USING GIN(tsv)" + ) + + async def _backfill_code_symbols_if_needed(self) -> None: + """Populate code_symbols for ready code files indexed before this table existed.""" + from opendb_core.database import get_pool + + pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT + f.id::text AS file_id, + f.filename, + f.metadata, + p.page_number, + p.text, + p.line_start, + p.line_end + FROM files f + JOIN pages p ON p.file_id = f.id + WHERE f.status = 'ready' + AND NOT EXISTS ( + SELECT 1 FROM code_symbols s WHERE s.file_id = f.id + ) + AND ( + lower(f.filename) LIKE '%.py' + OR lower(f.filename) LIKE '%.js' + OR lower(f.filename) LIKE '%.jsx' + OR lower(f.filename) LIKE '%.ts' + OR lower(f.filename) LIKE '%.tsx' + OR lower(f.metadata->>'source_path') LIKE '%.py' + OR lower(f.metadata->>'source_path') LIKE '%.js' + OR lower(f.metadata->>'source_path') LIKE '%.jsx' + OR lower(f.metadata->>'source_path') LIKE '%.ts' + OR lower(f.metadata->>'source_path') LIKE '%.tsx' + ) + ORDER BY f.id, p.page_number + """ + ) + if not rows: + return + + from opendb_core.config import settings + from opendb_core.utils.code_intel import extract_code_intel_from_pages + + current_file_id = None + group = [] + backfilled = 0 + + async def flush_group(items) -> None: + nonlocal backfilled + if not items: + return + first = items[0] + metadata = first["metadata"] + if isinstance(metadata, str): + metadata = json.loads(metadata) + source_path = (metadata or {}).get("source_path") or first["filename"] + pages = [ + SimpleNamespace(page_number=row["page_number"], text=row["text"]) + for row in items + ] + line_ranges = [(row["line_start"], row["line_end"]) for row in items] + symbols, code_links = extract_code_intel_from_pages( + pages, + line_ranges, + filename=first["filename"], + source_path=source_path, + ) + await self._replace_code_symbols_unlocked(conn, first["file_id"], symbols) + if settings.link_extraction_enabled and code_links: + await self._insert_file_links_unlocked(conn, first["file_id"], code_links) + backfilled += 1 + + async with conn.transaction(): + for row in rows: + if current_file_id is None: + current_file_id = row["file_id"] + if row["file_id"] != current_file_id: + await flush_group(group) + group = [] + current_file_id = row["file_id"] + group.append(row) + await flush_group(group) + if backfilled: + logger.info("Backfilled code symbols for %d existing code files.", backfilled) # ------------------------------------------------------------------ # Ingestion @@ -252,10 +370,19 @@ async def persist_ingestion( ) from opendb_core.config import settings + source_path = merged_metadata.get("source_path") or original_filename + from opendb_core.utils.code_intel import extract_code_intel_from_pages + symbols, code_links = extract_code_intel_from_pages( + parse_result.pages, + page_line_ranges, + filename=original_filename, + source_path=source_path, + ) + await self._replace_code_symbols_unlocked(conn, str(file_uuid), symbols) if settings.link_extraction_enabled: - source_path = merged_metadata.get("source_path") or original_filename from opendb_core.utils.link_extractor import extract_file_links links = extract_file_links(full_text, source_path=source_path) + links.extend(code_links) await self._replace_file_links_unlocked(conn, str(file_uuid), links) await conn.execute( @@ -870,6 +997,12 @@ async def _replace_file_links_unlocked(self, conn, file_id: str, links: list[dic file_uuid = _uuid.UUID(file_id) await conn.execute("DELETE FROM file_links WHERE from_file_id = $1", file_uuid) + return await self._insert_file_links_unlocked(conn, file_id, links) + + async def _insert_file_links_unlocked(self, conn, file_id: str, links: list[dict]) -> int: + import uuid as _uuid + + file_uuid = _uuid.UUID(file_id) rows = [] for link in links: target = str(link.get("target", "")).strip() @@ -962,6 +1095,108 @@ async def get_backlink_counts(self, file_ids: list[str]) -> dict[str, int]: ) return {str(r["to_file_id"]): int(r["cnt"]) for r in rows} + async def _replace_code_symbols_unlocked(self, conn, file_id: str, symbols: list[dict]) -> int: + import uuid as _uuid + + file_uuid = _uuid.UUID(file_id) + await conn.execute("DELETE FROM code_symbols WHERE file_id = $1", file_uuid) + rows = [] + for symbol in symbols: + name = str(symbol.get("name") or "").strip() + if not name: + continue + rows.append(( + file_uuid, + name, + str(symbol.get("kind") or "symbol"), + str(symbol.get("qualified_name") or name), + int(symbol.get("start_line") or 1), + int(symbol.get("end_line") or symbol.get("start_line") or 1), + str(symbol.get("signature") or "")[:1000], + str(symbol.get("docstring") or "")[:2000], + )) + if rows: + await conn.executemany( + """ + INSERT INTO code_symbols + (file_id, name, kind, qualified_name, start_line, end_line, signature, docstring) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + """, + rows, + ) + return len(rows) + + async def search_code_symbols( + self, + query: str, + *, + limit: int = 20, + kinds: list[str] | None = None, + ) -> list[dict]: + from opendb_core.database import get_pool + pool = await get_pool() + async with pool.acquire() as conn: + kind_clause = "" + params: list = [query, limit] + if kinds: + kind_clause = "AND s.kind = ANY($3::text[])" + params.append(kinds) + rows = await conn.fetch( + f""" + SELECT + s.file_id::text AS file_id, + f.filename, + f.metadata->>'source_path' AS source_path, + s.name, + s.kind, + s.qualified_name, + s.start_line, + s.end_line, + s.signature, + s.docstring, + ts_rank(s.tsv, plainto_tsquery('english', $1)) AS rank + FROM code_symbols s + JOIN files f ON f.id = s.file_id + WHERE s.tsv @@ plainto_tsquery('english', $1) + AND f.status = 'ready' + {kind_clause} + ORDER BY rank DESC, length(s.qualified_name), s.start_line + LIMIT $2 + """, + *params, + ) + if not rows: + kind_clause = "" + params = [f"%{query}%", limit] + if kinds: + kind_clause = "AND s.kind = ANY($3::text[])" + params.append(kinds) + rows = await conn.fetch( + f""" + SELECT + s.file_id::text AS file_id, + f.filename, + f.metadata->>'source_path' AS source_path, + s.name, + s.kind, + s.qualified_name, + s.start_line, + s.end_line, + s.signature, + s.docstring, + 0.0 AS rank + FROM code_symbols s + JOIN files f ON f.id = s.file_id + WHERE (s.name ILIKE $1 OR s.qualified_name ILIKE $1) + AND f.status = 'ready' + {kind_clause} + ORDER BY length(s.qualified_name), s.start_line + LIMIT $2 + """, + *params, + ) + return [dict(r) for r in rows] + # --------------------------------------------------------------------------- # Helpers diff --git a/opendb_core/storage/sqlite.py b/opendb_core/storage/sqlite.py index 98ea404..c06a73f 100644 --- a/opendb_core/storage/sqlite.py +++ b/opendb_core/storage/sqlite.py @@ -17,6 +17,7 @@ import json import logging from pathlib import Path +from types import SimpleNamespace from opendb_core.storage.shared import ( build_highlight, @@ -153,6 +154,48 @@ CREATE INDEX IF NOT EXISTS idx_file_links_from ON file_links(from_file_id); CREATE INDEX IF NOT EXISTS idx_file_links_to ON file_links(to_file_id); CREATE INDEX IF NOT EXISTS idx_file_links_target ON file_links(target); + +-- ----------------------------------------------------------------- +-- Lightweight code symbols +-- ----------------------------------------------------------------- +CREATE TABLE IF NOT EXISTS code_symbols ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id TEXT NOT NULL REFERENCES files(id) ON DELETE CASCADE, + name TEXT NOT NULL, + kind TEXT NOT NULL, + qualified_name TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + signature TEXT NOT NULL DEFAULT '', + docstring TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')) +); + +CREATE INDEX IF NOT EXISTS idx_code_symbols_file ON code_symbols(file_id); +CREATE INDEX IF NOT EXISTS idx_code_symbols_name ON code_symbols(name); +CREATE INDEX IF NOT EXISTS idx_code_symbols_kind ON code_symbols(kind); + +CREATE VIRTUAL TABLE IF NOT EXISTS code_symbols_fts USING fts5( + name, + qualified_name, + signature, + docstring +); + +CREATE TRIGGER IF NOT EXISTS code_symbols_ai AFTER INSERT ON code_symbols BEGIN + INSERT INTO code_symbols_fts(rowid, name, qualified_name, signature, docstring) + VALUES (NEW.id, NEW.name, NEW.qualified_name, NEW.signature, NEW.docstring); +END; + +CREATE TRIGGER IF NOT EXISTS code_symbols_ad AFTER DELETE ON code_symbols BEGIN + DELETE FROM code_symbols_fts WHERE rowid = OLD.id; +END; + +CREATE TRIGGER IF NOT EXISTS code_symbols_au AFTER UPDATE ON code_symbols BEGIN + DELETE FROM code_symbols_fts WHERE rowid = OLD.id; + INSERT INTO code_symbols_fts(rowid, name, qualified_name, signature, docstring) + VALUES (NEW.id, NEW.name, NEW.qualified_name, NEW.signature, NEW.docstring); +END; """ @@ -188,6 +231,7 @@ async def init(self) -> None: await self._db.executescript(_SCHEMA) await self._migrate_memories_pinned() await self._migrate_memories_v2() + await self._backfill_code_symbols_if_needed() await self._db.commit() logger.info("SQLite backend initialised at %s", self._db_path) @@ -264,6 +308,88 @@ async def _migrate_memories_v2(self) -> None: except (aiosqlite.OperationalError, aiosqlite.DatabaseError): pass + async def _backfill_code_symbols_if_needed(self) -> None: + """Populate code_symbols for ready code files indexed before this table existed.""" + try: + async with self._db.execute( + """ + SELECT + f.id AS file_id, + f.filename, + f.metadata, + p.page_number, + p.text, + p.line_start, + p.line_end + FROM files f + JOIN pages p ON p.file_id = f.id + WHERE f.status = 'ready' + AND NOT EXISTS ( + SELECT 1 FROM code_symbols s WHERE s.file_id = f.id + ) + AND ( + lower(f.filename) GLOB '*.py' + OR lower(f.filename) GLOB '*.js' + OR lower(f.filename) GLOB '*.jsx' + OR lower(f.filename) GLOB '*.ts' + OR lower(f.filename) GLOB '*.tsx' + OR lower(json_extract(f.metadata, '$.source_path')) GLOB '*.py' + OR lower(json_extract(f.metadata, '$.source_path')) GLOB '*.js' + OR lower(json_extract(f.metadata, '$.source_path')) GLOB '*.jsx' + OR lower(json_extract(f.metadata, '$.source_path')) GLOB '*.ts' + OR lower(json_extract(f.metadata, '$.source_path')) GLOB '*.tsx' + ) + ORDER BY f.id, p.page_number + """ + ) as cur: + rows = await cur.fetchall() + except (aiosqlite.OperationalError, aiosqlite.DatabaseError): + return + if not rows: + return + + from opendb_core.config import settings + from opendb_core.utils.code_intel import extract_code_intel_from_pages + + current_file_id = None + group: list[aiosqlite.Row] = [] + backfilled = 0 + + async def flush_group(items: list[aiosqlite.Row]) -> None: + nonlocal backfilled + if not items: + return + first = items[0] + metadata = json.loads(first["metadata"]) if first["metadata"] else {} + source_path = metadata.get("source_path") or first["filename"] + pages = [ + SimpleNamespace(page_number=row["page_number"], text=row["text"]) + for row in items + ] + line_ranges = [(row["line_start"], row["line_end"]) for row in items] + symbols, code_links = extract_code_intel_from_pages( + pages, + line_ranges, + filename=first["filename"], + source_path=source_path, + ) + await self._replace_code_symbols_unlocked(first["file_id"], symbols) + if settings.link_extraction_enabled and code_links: + await self._insert_file_links_unlocked(first["file_id"], code_links) + backfilled += 1 + + for row in rows: + if current_file_id is None: + current_file_id = row["file_id"] + if row["file_id"] != current_file_id: + await flush_group(group) + group = [] + current_file_id = row["file_id"] + group.append(row) + await flush_group(group) + if backfilled: + logger.info("Backfilled code symbols for %d existing code files.", backfilled) + async def close(self) -> None: if self._db: await self._db.close() @@ -383,10 +509,19 @@ async def persist_ingestion( ) from opendb_core.config import settings + source_path = merged_metadata.get("source_path") or original_filename + from opendb_core.utils.code_intel import extract_code_intel_from_pages + symbols, code_links = extract_code_intel_from_pages( + parse_result.pages, + page_line_ranges, + filename=original_filename, + source_path=source_path, + ) + await self._replace_code_symbols_unlocked(file_id, symbols) if settings.link_extraction_enabled: - source_path = merged_metadata.get("source_path") or original_filename from opendb_core.utils.link_extractor import extract_file_links links = extract_file_links(full_text, source_path=source_path) + links.extend(code_links) await self._replace_file_links_unlocked(file_id, links) await self._db.execute( @@ -919,6 +1054,9 @@ async def _resolve_link_target(self, target: str) -> str | None: async def _replace_file_links_unlocked(self, file_id: str, links: list[dict]) -> int: await self._db.execute("DELETE FROM file_links WHERE from_file_id = ?", (file_id,)) + return await self._insert_file_links_unlocked(file_id, links) + + async def _insert_file_links_unlocked(self, file_id: str, links: list[dict]) -> int: rows = [] for link in links: target = str(link.get("target", "")).strip() @@ -999,3 +1137,121 @@ async def get_backlink_counts(self, file_ids: list[str]) -> dict[str, int]: ) as cur: rows = await cur.fetchall() return {r["to_file_id"]: int(r["cnt"]) for r in rows} + + async def _replace_code_symbols_unlocked( + self, + file_id: str, + symbols: list[dict], + ) -> int: + await self._db.execute("DELETE FROM code_symbols WHERE file_id = ?", (file_id,)) + + rows = [] + for symbol in symbols: + name = str(symbol.get("name") or "").strip() + if not name: + continue + rows.append(( + file_id, + name, + str(symbol.get("kind") or "symbol"), + str(symbol.get("qualified_name") or name), + int(symbol.get("start_line") or 1), + int(symbol.get("end_line") or symbol.get("start_line") or 1), + str(symbol.get("signature") or "")[:1000], + str(symbol.get("docstring") or "")[:2000], + )) + if not rows: + return 0 + await self._db.executemany( + """ + INSERT INTO code_symbols + (file_id, name, kind, qualified_name, start_line, end_line, signature, docstring) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + rows, + ) + return len(rows) + + async def search_code_symbols( + self, + query: str, + *, + limit: int = 20, + kinds: list[str] | None = None, + ) -> list[dict]: + from opendb_core.utils.tokenizer import tokenize_for_fts + + match_query = tokenize_for_fts(query) + params: list[object] = [match_query] + kind_clause = "" + if kinds: + placeholders = ",".join("?" for _ in kinds) + kind_clause = f"AND s.kind IN ({placeholders})" + params.extend(kinds) + params.append(limit) + try: + async with self._db.execute( + f""" + SELECT + s.file_id, + f.filename, + json_extract(f.metadata, '$.source_path') AS source_path, + s.name, + s.kind, + s.qualified_name, + s.start_line, + s.end_line, + s.signature, + s.docstring, + bm25(code_symbols_fts) AS rank + FROM code_symbols_fts + JOIN code_symbols s ON s.id = code_symbols_fts.rowid + JOIN files f ON f.id = s.file_id + WHERE code_symbols_fts MATCH ? + AND f.status = 'ready' + {kind_clause} + ORDER BY rank + LIMIT ? + """, + params, + ) as cur: + rows = await cur.fetchall() + except aiosqlite.OperationalError: + rows = [] + if rows: + return [dict(r) for r in rows] + else: + like = f"%{query}%" + params = [like, like] + kind_clause = "" + if kinds: + placeholders = ",".join("?" for _ in kinds) + kind_clause = f"AND s.kind IN ({placeholders})" + params.extend(kinds) + params.append(limit) + async with self._db.execute( + f""" + SELECT + s.file_id, + f.filename, + json_extract(f.metadata, '$.source_path') AS source_path, + s.name, + s.kind, + s.qualified_name, + s.start_line, + s.end_line, + s.signature, + s.docstring, + 0.0 AS rank + FROM code_symbols s + JOIN files f ON f.id = s.file_id + WHERE (s.name LIKE ? OR s.qualified_name LIKE ?) + AND f.status = 'ready' + {kind_clause} + ORDER BY length(s.qualified_name), s.start_line + LIMIT ? + """, + params, + ) as cur: + rows = await cur.fetchall() + return [dict(r) for r in rows] diff --git a/opendb_core/utils/code_intel.py b/opendb_core/utils/code_intel.py new file mode 100644 index 0000000..3390db4 --- /dev/null +++ b/opendb_core/utils/code_intel.py @@ -0,0 +1,283 @@ +"""Small deterministic code intelligence helpers. + +This is intentionally narrower than a full code graph. It extracts enough +structure to make local agent lookups cheaper: Python symbols and import links, +plus a few regex-based symbols for JS/TS-style files. +""" + +from __future__ import annotations + +import ast +import posixpath +import re +from pathlib import Path + + +_PY_EXTENSIONS = {".py"} +_JS_EXTENSIONS = {".js", ".jsx", ".ts", ".tsx", ".mjs", ".cjs"} +_CODE_EXTENSIONS = _PY_EXTENSIONS | _JS_EXTENSIONS + +_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?(?:default\s+)?class\s+([A-Za-z_$][\w$]*)", re.M) +_JS_FUNCTION_RE = re.compile( + r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(", + re.M, +) +_JS_CONST_FN_RE = re.compile( + r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(", + re.M, +) +_JS_IMPORT_RE = re.compile( + r"""(?:import\s+(?:.+?\s+from\s+)?|export\s+.+?\s+from\s+)['"]([^'"]+)['"]|require\(['"]([^'"]+)['"]\)""", + re.S, +) + + +def is_code_path(filename: str, source_path: str | None = None) -> bool: + suffix = Path(source_path or filename).suffix.lower() + return suffix in _CODE_EXTENSIONS + + +def extract_code_intel( + content: str, + *, + filename: str, + source_path: str | None = None, +) -> tuple[list[dict], list[dict]]: + """Return ``(symbols, links)`` extracted from source text. + + Symbols are dictionaries ready for storage. Links use the same shape as + ``extract_file_links`` so they can be stored in the existing file_links + table. + """ + path = (source_path or filename).replace("\\", "/") + suffix = Path(path).suffix.lower() + if suffix == ".py": + return _extract_python(content, path) + if suffix in _JS_EXTENSIONS: + return _extract_js_like(content, path) + return [], [] + + +def extract_code_intel_from_pages( + pages, + page_line_ranges: list[tuple[int, int]], + *, + filename: str, + source_path: str | None = None, +) -> tuple[list[dict], list[dict]]: + """Extract code intel from parser pages while preserving assembled line numbers.""" + all_symbols: list[dict] = [] + all_links: list[dict] = [] + for index, page in enumerate(pages): + symbols, links = extract_code_intel( + page.text, + filename=filename, + source_path=source_path, + ) + offset = 0 + if index < len(page_line_ranges): + offset = page_line_ranges[index][0] - 1 + for symbol in symbols: + adjusted = dict(symbol) + adjusted["start_line"] = int(adjusted["start_line"]) + offset + adjusted["end_line"] = int(adjusted["end_line"]) + offset + all_symbols.append(adjusted) + all_links.extend(links) + return all_symbols, _dedupe_links(all_links) + + +def _symbol( + *, + name: str, + kind: str, + qualified_name: str, + start_line: int, + end_line: int, + signature: str = "", + docstring: str = "", +) -> dict: + return { + "name": name, + "kind": kind, + "qualified_name": qualified_name, + "start_line": start_line, + "end_line": end_line, + "signature": signature, + "docstring": docstring, + } + + +def _extract_python(content: str, source_path: str) -> tuple[list[dict], list[dict]]: + try: + tree = ast.parse(content) + except SyntaxError: + return [], [] + + symbols: list[dict] = [] + links: list[dict] = [] + + def signature(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str: + args = [arg.arg for arg in node.args.posonlyargs + node.args.args] + if node.args.vararg: + args.append(f"*{node.args.vararg.arg}") + args.extend(arg.arg for arg in node.args.kwonlyargs) + if node.args.kwarg: + args.append(f"**{node.args.kwarg.arg}") + return f"{node.name}({', '.join(args)})" + + class Visitor(ast.NodeVisitor): + def __init__(self) -> None: + self.stack: list[str] = [] + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + qualified = ".".join([*self.stack, node.name]) + symbols.append(_symbol( + name=node.name, + kind="class", + qualified_name=qualified, + start_line=node.lineno, + end_line=getattr(node, "end_lineno", node.lineno), + signature=node.name, + docstring=ast.get_docstring(node) or "", + )) + self.stack.append(node.name) + self.generic_visit(node) + self.stack.pop() + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + self._visit_function(node, "method" if self.stack else "function") + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + self._visit_function(node, "method" if self.stack else "function") + + def _visit_function( + self, + node: ast.FunctionDef | ast.AsyncFunctionDef, + kind: str, + ) -> None: + qualified = ".".join([*self.stack, node.name]) + symbols.append(_symbol( + name=node.name, + kind=kind, + qualified_name=qualified, + start_line=node.lineno, + end_line=getattr(node, "end_lineno", node.lineno), + signature=signature(node), + docstring=ast.get_docstring(node) or "", + )) + self.stack.append(node.name) + self.generic_visit(node) + self.stack.pop() + + Visitor().visit(tree) + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + links.extend(_python_import_targets(alias.name, source_path, node.lineno)) + elif isinstance(node, ast.ImportFrom): + module = "." * node.level + (node.module or "") + links.extend(_python_import_targets(module, source_path, node.lineno)) + + return symbols, _dedupe_links(links) + + +def _python_import_targets(module: str, source_path: str, line: int) -> list[dict]: + if not module: + return [] + if module.startswith("."): + target_base = _resolve_relative_python_module(module, source_path) + else: + target_base = module.replace(".", "/") + if not target_base: + return [] + return [ + { + "target": f"{target_base}.py", + "link_type": "import", + "context": f"import {module} (line {line})", + }, + { + "target": f"{target_base}/__init__.py", + "link_type": "import", + "context": f"import {module} (line {line})", + }, + ] + + +def _resolve_relative_python_module(module: str, source_path: str) -> str | None: + dots = len(module) - len(module.lstrip(".")) + rest = module.lstrip(".") + directory = posixpath.dirname(source_path) + parts = directory.split("/") if directory else [] + keep = max(0, len(parts) - max(0, dots - 1)) + base = "/".join(parts[:keep]) + if rest: + rest_path = rest.replace(".", "/") + return posixpath.normpath(posixpath.join(base, rest_path)) + return posixpath.normpath(base) if base else None + + +def _extract_js_like(content: str, source_path: str) -> tuple[list[dict], list[dict]]: + symbols: list[dict] = [] + for kind, regex in [ + ("class", _JS_CLASS_RE), + ("function", _JS_FUNCTION_RE), + ("function", _JS_CONST_FN_RE), + ]: + for match in regex.finditer(content): + line = content.count("\n", 0, match.start()) + 1 + name = match.group(1) + symbols.append(_symbol( + name=name, + kind=kind, + qualified_name=name, + start_line=line, + end_line=line, + signature=name, + )) + + links: list[dict] = [] + for match in _JS_IMPORT_RE.finditer(content): + target = match.group(1) or match.group(2) + if not target or not target.startswith("."): + continue + line = content.count("\n", 0, match.start()) + 1 + links.extend(_js_import_targets(target, source_path, line)) + + return symbols, _dedupe_links(links) + + +def _js_import_targets(target: str, source_path: str, line: int) -> list[dict]: + base = posixpath.normpath(posixpath.join(posixpath.dirname(source_path), target)) + suffix = Path(base).suffix + candidates = [base] if suffix else [ + f"{base}.ts", + f"{base}.tsx", + f"{base}.js", + f"{base}.jsx", + f"{base}/index.ts", + f"{base}/index.tsx", + f"{base}/index.js", + f"{base}/index.jsx", + ] + return [ + { + "target": candidate, + "link_type": "import", + "context": f"import {target} (line {line})", + } + for candidate in candidates + ] + + +def _dedupe_links(links: list[dict]) -> list[dict]: + seen: set[tuple[str, str]] = set() + unique: list[dict] = [] + for link in links: + key = (link["target"], link["link_type"]) + if key in seen: + continue + seen.add(key) + unique.append(link) + return unique diff --git a/opendb_core/workspace.py b/opendb_core/workspace.py index 733c9e5..184cf3b 100644 --- a/opendb_core/workspace.py +++ b/opendb_core/workspace.py @@ -216,6 +216,20 @@ async def search( from opendb_core.services.search_service import search_files return await search_files(query=query, limit=limit, offset=offset) + async def context( + self, + query: str, + limit: int = 8, + include_snippets: bool = True, + ) -> dict: + """Build compact agent context from indexed symbols and FTS hits.""" + from opendb_core.services.context_service import build_context + return await build_context( + query=query, + limit=limit, + include_snippets=include_snippets, + ) + async def info(self) -> dict: """Return workspace statistics.""" from opendb_core.storage import get_backend diff --git a/sql/schema.sql b/sql/schema.sql index 47758e8..092c092 100644 --- a/sql/schema.sql +++ b/sql/schema.sql @@ -149,3 +149,27 @@ CREATE TABLE file_links ( CREATE INDEX idx_file_links_from ON file_links(from_file_id); CREATE INDEX idx_file_links_to ON file_links(to_file_id); CREATE INDEX idx_file_links_target ON file_links(target); + +-- ============================================================ +-- code_symbols: lightweight code outline for local agent context +-- ============================================================ +CREATE TABLE code_symbols ( + id BIGSERIAL PRIMARY KEY, + file_id UUID NOT NULL REFERENCES files(id) ON DELETE CASCADE, + name TEXT NOT NULL, + kind TEXT NOT NULL, + qualified_name TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + signature TEXT NOT NULL DEFAULT '', + docstring TEXT NOT NULL DEFAULT '', + tsv TSVECTOR GENERATED ALWAYS AS ( + to_tsvector('english', name || ' ' || qualified_name || ' ' || signature || ' ' || docstring) + ) STORED, + created_at TIMESTAMPTZ DEFAULT now() +); + +CREATE INDEX idx_code_symbols_file ON code_symbols(file_id); +CREATE INDEX idx_code_symbols_name ON code_symbols(name); +CREATE INDEX idx_code_symbols_kind ON code_symbols(kind); +CREATE INDEX idx_code_symbols_tsv ON code_symbols USING GIN(tsv); diff --git a/tests/test_code_context.py b/tests/test_code_context.py new file mode 100644 index 0000000..3e9683a --- /dev/null +++ b/tests/test_code_context.py @@ -0,0 +1,214 @@ +"""Tests for lightweight code symbols, import links, and context bundles.""" + +from __future__ import annotations + +import pytest + +from opendb_core.parsers.base import Page, ParseResult +from opendb_core.services.context_service import build_context +from opendb_core.storage import close_backend, get_backend, init_backend +from opendb_core.storage.sqlite import SQLiteBackend +from opendb_core.utils.text import assemble_text +from opendb_core.utils.code_intel import extract_code_intel + + +def _parse_code(text: str) -> tuple[ParseResult, str, list[int], list[tuple[int, int]]]: + result = ParseResult(pages=[Page(page_number=1, section_title=None, text=text)]) + full_text, line_index, _toc, page_line_ranges = assemble_text(result.pages, "text/x-python") + return result, full_text, line_index, page_line_ranges + + +async def _ingest_code( + backend: SQLiteBackend, + *, + file_id: str, + filename: str, + source_path: str, + text: str, +) -> None: + parse_result, full_text, line_index, page_line_ranges = _parse_code(text) + await backend.persist_ingestion( + file_id=file_id, + file_path=f"/tmp/{filename}", + original_filename=filename, + mime_type="text/x-python", + file_size=len(text), + checksum=f"checksum-{file_id}", + tags=[], + merged_metadata={"source_path": source_path}, + parse_result=parse_result, + full_text=full_text, + total_lines=len(line_index), + line_index=line_index, + toc="", + page_line_ranges=page_line_ranges, + ) + + +@pytest.fixture +async def backend(tmp_path): + db_path = tmp_path / "code-context.db" + b = SQLiteBackend(db_path=db_path) + await b.init() + yield b + await b.close() + + +def test_extracts_python_symbols_and_import_links() -> None: + symbols, links = extract_code_intel( + "from opendb_core.storage.sqlite import SQLiteBackend\n\n" + "class SearchService:\n" + " def run(self, query):\n" + " return query\n", + filename="service.py", + source_path="/workspace/app/service.py", + ) + + assert {s["qualified_name"] for s in symbols} == { + "SearchService", + "SearchService.run", + } + assert "opendb_core/storage/sqlite.py" in {link["target"] for link in links} + + +class TestCodeContext: + @pytest.mark.asyncio + async def test_ingest_indexes_symbols_for_targeted_lookup(self, backend) -> None: + await _ingest_code( + backend, + file_id="worker", + filename="worker.py", + source_path="/workspace/pkg/worker.py", + text=( + "def do_work(item):\n" + " \"\"\"Process one queued item.\"\"\"\n" + " return item.upper()\n" + ), + ) + + rows = await backend.search_code_symbols("do_work", limit=5) + + assert len(rows) == 1 + assert rows[0]["qualified_name"] == "do_work" + assert rows[0]["source_path"] == "/workspace/pkg/worker.py" + + await backend.delete_file("worker") + assert await backend.search_code_symbols("do_work", limit=5) == [] + + @pytest.mark.asyncio + async def test_python_imports_feed_existing_backlink_graph(self, backend) -> None: + await _ingest_code( + backend, + file_id="worker", + filename="worker.py", + source_path="/workspace/pkg/worker.py", + text="def do_work(item):\n return item\n", + ) + await _ingest_code( + backend, + file_id="runner", + filename="runner.py", + source_path="/workspace/pkg/runner.py", + text="from pkg.worker import do_work\n\nresult = do_work('x')\n", + ) + + assert await backend.get_backlink_counts(["worker"]) == {"worker": 1} + + @pytest.mark.asyncio + async def test_context_returns_symbol_snippets_without_reading_whole_file(self, tmp_path) -> None: + db_path = tmp_path / "context-service.db" + await init_backend("sqlite", db_path=db_path) + try: + backend = get_backend() + await _ingest_code( + backend, + file_id="worker", + filename="worker.py", + source_path="/workspace/pkg/worker.py", + text=( + "def do_work(item):\n" + " \"\"\"Process one queued item.\"\"\"\n" + " return item.upper()\n\n" + "def unrelated():\n" + " return 'noise'\n" + ), + ) + + result = await build_context("do_work", limit=3) + finally: + await close_backend(key=str(db_path)) + + assert result["stats"]["symbol_count"] == 1 + assert result["symbols"][0]["qualified_name"] == "do_work" + assert "def do_work" in result["snippets"][0]["text"] + assert "def unrelated" not in result["snippets"][0]["text"] + + @pytest.mark.asyncio + async def test_init_backfills_symbols_for_existing_indexed_code(self, tmp_path) -> None: + db_path = tmp_path / "upgrade.db" + text = "from pkg.worker import do_work\n\ndef old_symbol():\n return do_work('x')\n" + parse_result, full_text, line_index, page_line_ranges = _parse_code(text) + + backend = SQLiteBackend(db_path=db_path) + await backend.init() + try: + await backend._db.execute( + """ + INSERT INTO files + (id, filename, mime_type, file_size, file_path, checksum, status, tags, metadata) + VALUES (?, ?, ?, ?, ?, ?, 'ready', '[]', ?) + """, + ( + "old-runner", + "runner.py", + "text/x-python", + len(text), + "/tmp/runner.py", + "checksum-old-runner", + '{"source_path": "/workspace/pkg/runner.py"}', + ), + ) + await backend._db.execute( + "INSERT INTO file_text (file_id, full_text, total_lines, line_index, toc) " + "VALUES (?, ?, ?, ?, '')", + ("old-runner", full_text, len(line_index), str(line_index)), + ) + await backend._db.execute( + """ + INSERT INTO pages (file_id, page_number, text, line_start, line_end) + VALUES (?, ?, ?, ?, ?) + """, + ( + "old-runner", + 1, + parse_result.pages[0].text, + page_line_ranges[0][0], + page_line_ranges[0][1], + ), + ) + await backend._db.execute( + """ + INSERT INTO file_links (from_file_id, target, link_type, context) + VALUES (?, ?, 'markdown', 'existing link') + """, + ("old-runner", "README.md"), + ) + await backend._db.commit() + finally: + await backend.close() + + reopened = SQLiteBackend(db_path=db_path) + await reopened.init() + try: + rows = await reopened.search_code_symbols("old_symbol", limit=5) + async with reopened._db.execute( + "SELECT target, link_type FROM file_links WHERE from_file_id = ?", + ("old-runner",), + ) as cur: + links = {(row["target"], row["link_type"]) for row in await cur.fetchall()} + finally: + await reopened.close() + + assert [row["qualified_name"] for row in rows] == ["old_symbol"] + assert ("README.md", "markdown") in links + assert ("pkg/worker.py", "import") in links