diff --git a/src/connector/adapter/mcp/server.rs b/src/connector/adapter/mcp/server.rs index 5a62f13..9609cc6 100644 --- a/src/connector/adapter/mcp/server.rs +++ b/src/connector/adapter/mcp/server.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; +use std::path::Path; use std::sync::Arc; use rmcp::handler::server::tool::ToolRouter; @@ -11,8 +13,9 @@ use rmcp::tool_router; use rmcp::ErrorData as McpError; use rmcp::ServerHandler; use schemars::JsonSchema; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use crate::application::CallGraphQuery; use crate::connector::api::Container; use crate::domain::SearchQuery; @@ -92,6 +95,65 @@ pub struct ContextToolInput { pub regex: bool, } +/// Relationship pattern for the query_graph tool. +#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum QueryPattern { + CallersOf, + CalleesOf, + ImportsOf, + ImportersOf, + InheritorsOf, + ChildrenOf, + TestsFor, + FileSummary, +} + +/// Input parameters for the query_graph tool +#[derive(Debug, Deserialize, JsonSchema)] +pub struct QueryGraphInput { + /// Relationship pattern to query. + pub pattern: QueryPattern, + + /// Symbol name or file path (for file_summary) to query. + /// Resolved with the same substring-match fallback as analyze_impact. + pub target: String, + + /// Restrict results to a specific repository ID. + pub repository_id: Option, + + /// Maximum number of unique nodes to return. Omit to return all results. + pub limit: Option, +} + +/// A single deduplicated graph node returned by query_graph +#[derive(Debug, Serialize)] +pub struct GraphQueryNode { + /// The symbol name (caller or callee depending on pattern) + pub symbol: String, + /// File path where the reference occurs + pub file_path: String, + /// Line number where the reference occurs + pub line: u32, + /// The kind of relationship (e.g. "call", "import", "inheritance") + pub reference_kind: String, + /// Repository the node belongs to + pub repository_id: String, +} + +/// Result returned by the query_graph tool +#[derive(Debug, Serialize)] +pub struct GraphQueryResult { + /// The pattern that was queried + pub pattern: QueryPattern, + /// The target symbol or file that was queried + pub target: String, + /// Deduplicated nodes matching the query + pub nodes: Vec, + /// Total number of nodes returned (after deduplication; equals len(nodes)) + pub total: usize, +} + // ── MCP Server ─────────────────────────────────────────────────────────────── /// MCP Server that exposes codesearch functionality @@ -215,6 +277,242 @@ impl CodesearchMcpServer { Ok(CallToolResult::success(vec![Content::text(json)])) } + + /// Query the call graph using an intention-named relationship pattern. + /// Returns deduplicated graph nodes for exactly the relationship type requested, + /// avoiding the noise of receiving all relationship kinds at once. + /// + /// Supported patterns: + /// • callers_of — who calls this symbol + /// • callees_of — what this symbol calls + /// • imports_of — what this symbol imports (Import edges only) + /// • importers_of — who imports this symbol (Import edges only) + /// • inheritors_of — who inherits from / implements this symbol + /// • children_of — what this symbol inherits from / implements + /// • tests_for — test functions or files that exercise this symbol + /// • file_summary — all symbols referenced within a file + /// + /// Requires the repository to have been indexed with call-graph support. + #[tool(name = "query_graph")] + async fn query_graph( + &self, + params: Parameters, + ) -> Result { + let input = params.0; + + let use_case = self.container.call_graph_use_case(); + + let mut base_query = CallGraphQuery::new(); + if let Some(repo_id) = &input.repository_id { + base_query = base_query.with_repository(repo_id.clone()); + } + if let Some(limit) = input.limit { + base_query = base_query.with_limit(limit as u32); + } + + // Each arm returns (references, use_caller). + // use_caller=true → node.symbol = caller_symbol (who performs the action) + // use_caller=false → node.symbol = callee_symbol (what is acted upon) + let (references, use_caller) = match input.pattern { + QueryPattern::CallersOf => { + let refs = use_case + .find_callers(&input.target, &base_query) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + (refs, true) + } + QueryPattern::CalleesOf => { + let refs = use_case + .find_callees(&input.target, &base_query) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + (refs, false) + } + QueryPattern::ImportsOf => { + let q = base_query.with_reference_kind("import"); + let refs = use_case + .find_callees(&input.target, &q) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + (refs, false) + } + QueryPattern::ImportersOf => { + let q = base_query.with_reference_kind("import"); + let refs = use_case + .find_callers(&input.target, &q) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + (refs, true) + } + QueryPattern::InheritorsOf => { + // Halve the per-query limit so the combined result stays within the + // requested bound before deduplication. + let per_limit = input.limit.map(|n| ((n + 1) / 2) as u32); + let q_inh = { + let q = base_query.clone().with_reference_kind("inheritance"); + match per_limit { + Some(pl) => q.with_limit(pl), + None => q, + } + }; + let q_imp = { + let q = base_query.clone().with_reference_kind("implementation"); + match per_limit { + Some(pl) => q.with_limit(pl), + None => q, + } + }; + let mut refs = use_case + .find_callers(&input.target, &q_inh) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + let mut refs2 = use_case + .find_callers(&input.target, &q_imp) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + refs.append(&mut refs2); + (refs, true) + } + QueryPattern::ChildrenOf => { + let per_limit = input.limit.map(|n| ((n + 1) / 2) as u32); + let q_inh = { + let q = base_query.clone().with_reference_kind("inheritance"); + match per_limit { + Some(pl) => q.with_limit(pl), + None => q, + } + }; + let q_imp = { + let q = base_query.clone().with_reference_kind("implementation"); + match per_limit { + Some(pl) => q.with_limit(pl), + None => q, + } + }; + let mut refs = use_case + .find_callees(&input.target, &q_inh) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + let mut refs2 = use_case + .find_callees(&input.target, &q_imp) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + refs.append(&mut refs2); + (refs, false) + } + QueryPattern::TestsFor => { + let refs = use_case + .find_callers(&input.target, &base_query) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + let filtered: Vec<_> = refs + .into_iter() + .filter(|r| { + // Symbol-name heuristics (language-agnostic conventions). + let sym = r.caller_symbol().unwrap_or("").to_lowercase(); + if sym.starts_with("test_") + || sym.ends_with("_test") + || sym.ends_with("_spec") + { + return true; + } + // Path heuristics: inspect components and file stem rather than + // doing a raw substring match to avoid false positives like + // "contest.rs" or "inspect.rs". + let path = Path::new(r.reference_file_path()); + let test_dir = path.components().any(|c| { + if let std::path::Component::Normal(s) = c { + let s = s.to_string_lossy().to_lowercase(); + matches!(s.as_str(), "test" | "tests" | "spec" | "specs") + } else { + false + } + }); + if test_dir { + return true; + } + path.file_stem() + .map(|s| { + let s = s.to_string_lossy().to_lowercase(); + s == "test" + || s.starts_with("test_") + || s.ends_with("_test") + || s.ends_with("_spec") + }) + .unwrap_or(false) + }) + .collect(); + (filtered, true) + } + QueryPattern::FileSummary => { + let refs = use_case + .find_by_file(&input.target, &base_query) + .await + .map_err(|e| { + McpError::internal_error(format!("query_graph failed: {}", e), None) + })?; + (refs, false) + } + }; + + // Deduplicate by symbol name, keeping the first reference site per unique symbol. + // When use_caller is true, entries without a caller_symbol are dropped — a file + // path is not a valid symbol and must not appear in GraphQueryNode.symbol. + let mut seen: HashSet = HashSet::new(); + let deduped = references.into_iter().filter_map(|r| { + let symbol = if use_caller { + r.caller_symbol()?.to_string() + } else { + r.callee_symbol().to_string() + }; + if symbol.is_empty() || !seen.insert(symbol.clone()) { + return None; + } + Some(GraphQueryNode { + symbol, + file_path: r.reference_file_path().to_string(), + line: r.reference_line(), + reference_kind: r.reference_kind().as_str().to_string(), + repository_id: r.repository_id().to_string(), + }) + }); + let nodes: Vec = match input.limit { + Some(n) => deduped.take(n).collect(), + None => deduped.collect(), + }; + + let total = nodes.len(); + let result = GraphQueryResult { + pattern: input.pattern, + target: input.target, + nodes, + total, + }; + + let json = serde_json::to_string_pretty(&result).map_err(|e| { + McpError::internal_error(format!("Failed to serialize result: {}", e), None) + })?; + + Ok(CallToolResult::success(vec![Content::text(json)])) + } } #[tool_handler] @@ -229,7 +527,9 @@ impl ServerHandler for CodesearchMcpServer { • search_code — find code by natural language description (set text_search=false \ to disable keyword+semantic fusion)\n\ • analyze_impact — blast-radius analysis: what breaks if symbol X changes?\n\ - • get_symbol_context — 360° view of a symbol's callers and callees" + • get_symbol_context — 360° view of a symbol's callers and callees\n\ + • query_graph — precise relationship queries: callers_of, callees_of, \ + imports_of, importers_of, inheritors_of, children_of, tests_for, file_summary" .into(), ), }