Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 302 additions & 2 deletions src/connector/adapter/mcp/server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashSet;
use std::path::Path;
use std::sync::Arc;

use rmcp::handler::server::tool::ToolRouter;
Expand All @@ -11,8 +13,9 @@ use rmcp::tool_router;
use rmcp::ErrorData as McpError;
use rmcp::ServerHandler;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::{Deserialize, Serialize};

use crate::application::CallGraphQuery;
use crate::connector::api::Container;
use crate::domain::SearchQuery;

Expand Down Expand Up @@ -92,6 +95,65 @@ pub struct ContextToolInput {
pub regex: bool,
}

/// Relationship pattern for the query_graph tool.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum QueryPattern {
CallersOf,
CalleesOf,
ImportsOf,
ImportersOf,
InheritorsOf,
ChildrenOf,
TestsFor,
FileSummary,
}

/// Input parameters for the query_graph tool
#[derive(Debug, Deserialize, JsonSchema)]
pub struct QueryGraphInput {
/// Relationship pattern to query.
pub pattern: QueryPattern,

/// Symbol name or file path (for file_summary) to query.
/// Resolved with the same substring-match fallback as analyze_impact.
pub target: String,

/// Restrict results to a specific repository ID.
pub repository_id: Option<String>,

/// Maximum number of unique nodes to return. Omit to return all results.
pub limit: Option<usize>,
}

/// A single deduplicated graph node returned by query_graph
#[derive(Debug, Serialize)]
pub struct GraphQueryNode {
/// The symbol name (caller or callee depending on pattern)
pub symbol: String,
/// File path where the reference occurs
pub file_path: String,
/// Line number where the reference occurs
pub line: u32,
/// The kind of relationship (e.g. "call", "import", "inheritance")
pub reference_kind: String,
/// Repository the node belongs to
pub repository_id: String,
}

/// Result returned by the query_graph tool
#[derive(Debug, Serialize)]
pub struct GraphQueryResult {
/// The pattern that was queried
pub pattern: QueryPattern,
/// The target symbol or file that was queried
pub target: String,
/// Deduplicated nodes matching the query
pub nodes: Vec<GraphQueryNode>,
/// Total number of nodes returned (after deduplication; equals len(nodes))
pub total: usize,
}

// ── MCP Server ───────────────────────────────────────────────────────────────

/// MCP Server that exposes codesearch functionality
Expand Down Expand Up @@ -215,6 +277,242 @@ impl CodesearchMcpServer {

Ok(CallToolResult::success(vec![Content::text(json)]))
}

/// Query the call graph using an intention-named relationship pattern.
/// Returns deduplicated graph nodes for exactly the relationship type requested,
/// avoiding the noise of receiving all relationship kinds at once.
///
/// Supported patterns:
/// • callers_of — who calls this symbol
/// • callees_of — what this symbol calls
/// • imports_of — what this symbol imports (Import edges only)
/// • importers_of — who imports this symbol (Import edges only)
/// • inheritors_of — who inherits from / implements this symbol
/// • children_of — what this symbol inherits from / implements
/// • tests_for — test functions or files that exercise this symbol
/// • file_summary — all symbols referenced within a file
///
/// Requires the repository to have been indexed with call-graph support.
#[tool(name = "query_graph")]
async fn query_graph(
&self,
params: Parameters<QueryGraphInput>,
) -> Result<CallToolResult, McpError> {
let input = params.0;

let use_case = self.container.call_graph_use_case();

let mut base_query = CallGraphQuery::new();
if let Some(repo_id) = &input.repository_id {
base_query = base_query.with_repository(repo_id.clone());
}
if let Some(limit) = input.limit {
base_query = base_query.with_limit(limit as u32);
}

// Each arm returns (references, use_caller).
// use_caller=true → node.symbol = caller_symbol (who performs the action)
// use_caller=false → node.symbol = callee_symbol (what is acted upon)
let (references, use_caller) = match input.pattern {
QueryPattern::CallersOf => {
let refs = use_case
.find_callers(&input.target, &base_query)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
(refs, true)
}
QueryPattern::CalleesOf => {
let refs = use_case
.find_callees(&input.target, &base_query)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
(refs, false)
}
QueryPattern::ImportsOf => {
let q = base_query.with_reference_kind("import");
let refs = use_case
.find_callees(&input.target, &q)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
(refs, false)
}
QueryPattern::ImportersOf => {
let q = base_query.with_reference_kind("import");
let refs = use_case
.find_callers(&input.target, &q)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
(refs, true)
}
QueryPattern::InheritorsOf => {
// Halve the per-query limit so the combined result stays within the
// requested bound before deduplication.
let per_limit = input.limit.map(|n| ((n + 1) / 2) as u32);
let q_inh = {
let q = base_query.clone().with_reference_kind("inheritance");
match per_limit {
Some(pl) => q.with_limit(pl),
None => q,
}
};
let q_imp = {
let q = base_query.clone().with_reference_kind("implementation");
match per_limit {
Some(pl) => q.with_limit(pl),
None => q,
}
};
let mut refs = use_case
.find_callers(&input.target, &q_inh)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
let mut refs2 = use_case
.find_callers(&input.target, &q_imp)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
refs.append(&mut refs2);
(refs, true)
}
QueryPattern::ChildrenOf => {
let per_limit = input.limit.map(|n| ((n + 1) / 2) as u32);
let q_inh = {
let q = base_query.clone().with_reference_kind("inheritance");
match per_limit {
Some(pl) => q.with_limit(pl),
None => q,
}
};
let q_imp = {
let q = base_query.clone().with_reference_kind("implementation");
match per_limit {
Some(pl) => q.with_limit(pl),
None => q,
}
};
let mut refs = use_case
.find_callees(&input.target, &q_inh)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
let mut refs2 = use_case
.find_callees(&input.target, &q_imp)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
refs.append(&mut refs2);
(refs, false)
}
QueryPattern::TestsFor => {
let refs = use_case
.find_callers(&input.target, &base_query)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
let filtered: Vec<_> = refs
.into_iter()
.filter(|r| {
// Symbol-name heuristics (language-agnostic conventions).
let sym = r.caller_symbol().unwrap_or("").to_lowercase();
if sym.starts_with("test_")
|| sym.ends_with("_test")
|| sym.ends_with("_spec")
{
return true;
}
// Path heuristics: inspect components and file stem rather than
// doing a raw substring match to avoid false positives like
// "contest.rs" or "inspect.rs".
let path = Path::new(r.reference_file_path());
let test_dir = path.components().any(|c| {
if let std::path::Component::Normal(s) = c {
let s = s.to_string_lossy().to_lowercase();
matches!(s.as_str(), "test" | "tests" | "spec" | "specs")
} else {
false
}
});
if test_dir {
return true;
}
path.file_stem()
.map(|s| {
let s = s.to_string_lossy().to_lowercase();
s == "test"
|| s.starts_with("test_")
|| s.ends_with("_test")
|| s.ends_with("_spec")
})
.unwrap_or(false)
})
.collect();
(filtered, true)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
QueryPattern::FileSummary => {
let refs = use_case
.find_by_file(&input.target, &base_query)
.await
.map_err(|e| {
McpError::internal_error(format!("query_graph failed: {}", e), None)
})?;
(refs, false)
}
};

// Deduplicate by symbol name, keeping the first reference site per unique symbol.
// When use_caller is true, entries without a caller_symbol are dropped — a file
// path is not a valid symbol and must not appear in GraphQueryNode.symbol.
let mut seen: HashSet<String> = HashSet::new();
let deduped = references.into_iter().filter_map(|r| {
let symbol = if use_caller {
r.caller_symbol()?.to_string()
} else {
r.callee_symbol().to_string()
};
if symbol.is_empty() || !seen.insert(symbol.clone()) {
return None;
}
Some(GraphQueryNode {
symbol,
file_path: r.reference_file_path().to_string(),
line: r.reference_line(),
reference_kind: r.reference_kind().as_str().to_string(),
repository_id: r.repository_id().to_string(),
})
});
let nodes: Vec<GraphQueryNode> = match input.limit {
Some(n) => deduped.take(n).collect(),
None => deduped.collect(),
};

let total = nodes.len();
let result = GraphQueryResult {
pattern: input.pattern,
target: input.target,
nodes,
total,
};

let json = serde_json::to_string_pretty(&result).map_err(|e| {
McpError::internal_error(format!("Failed to serialize result: {}", e), None)
})?;

Ok(CallToolResult::success(vec![Content::text(json)]))
}
}

#[tool_handler]
Expand All @@ -229,7 +527,9 @@ impl ServerHandler for CodesearchMcpServer {
• search_code — find code by natural language description (set text_search=false \
to disable keyword+semantic fusion)\n\
• analyze_impact — blast-radius analysis: what breaks if symbol X changes?\n\
• get_symbol_context — 360° view of a symbol's callers and callees"
• get_symbol_context — 360° view of a symbol's callers and callees\n\
• query_graph — precise relationship queries: callers_of, callees_of, \
imports_of, importers_of, inheritors_of, children_of, tests_for, file_summary"
.into(),
),
}
Expand Down
Loading