Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 13 additions & 28 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ var webFS embed.FS
type service interface {
Invoke(ctx context.Context, req invocation.CreateInvocationRequest) (invocation.InvocationResponse, error)
ListTools(ctx context.Context, server string) ([]mcp.Tool, error)
ListAllTools(ctx context.Context) ([]mcp.Tool, error)
ResolveToolServer(ctx context.Context, toolName string) (string, error)
Get(ctx context.Context, id string) (invocation.InvocationResponse, error)
List(ctx context.Context, filter invocation.InvocationListFilter) (invocation.InvocationListResponse, error)
ListAgentIDs(ctx context.Context) ([]string, error)
Expand Down Expand Up @@ -522,6 +520,7 @@ func (h *Handler) Routes() http.Handler {
}
mcpHandler := auth.MiddlewareWithOptions(h.authValidator, "/.well-known/oauth-protected-resource", auth.MiddlewareOptions{SkipVerify: h.authDebugSkip, DebugLogIdentity: h.debug})(http.HandlerFunc(h.invokeUpstream))
mcpHandler = h.noAuthAgentIDHint(mcpHandler)
mux.HandleFunc("/mcp", h.mcpRootNotFound)
mux.Handle("/mcp/", mcpHandler)
mux.HandleFunc("/api/v1/invocations", h.invocations)
mux.HandleFunc("/api/v1/admin/invocations", h.adminInvocations)
Expand Down Expand Up @@ -614,6 +613,10 @@ func (h *Handler) root(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/ui/", http.StatusFound)
}

func (h *Handler) mcpRootNotFound(w http.ResponseWriter, _ *http.Request) {
writeError(w, http.StatusNotFound, "MCP server name is required; use /mcp/{server}")
}

func (h *Handler) uiIndex(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/ui" {
writeError(w, http.StatusNotFound, "not found")
Expand Down Expand Up @@ -643,6 +646,10 @@ func (h *Handler) spaFileServer() http.Handler {
func (h *Handler) invokeUpstream(w http.ResponseWriter, r *http.Request) {
server := strings.TrimPrefix(r.URL.Path, "/mcp/")
server = strings.Trim(server, "/")
if server == "" {
writeError(w, http.StatusNotFound, "MCP server name is required; use /mcp/{server}")
return
}

// GET: open an SSE keepalive stream (Streamable HTTP transport and legacy SSE clients both try GET)
if r.Method == http.MethodGet {
Expand All @@ -658,14 +665,9 @@ func (h *Handler) invokeUpstream(w http.ResponseWriter, r *http.Request) {
return
}
if isJSONRPCRequest(r) {
// server may be "" — handleMCPProxy handles the aggregate (no-server) case
h.handleMCPProxy(w, r, server)
return
}
if server == "" {
writeError(w, http.StatusNotFound, "server not found")
return
}
h.handleInvocation(w, r, server)
}

Expand Down Expand Up @@ -916,13 +918,7 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server
case "notifications/initialized":
w.WriteHeader(http.StatusAccepted)
case "tools/list":
var tools []mcp.Tool
var err error
if server == "" {
tools, err = h.svc.ListAllTools(r.Context())
} else {
tools, err = h.svc.ListTools(r.Context(), server)
}
tools, err := h.svc.ListTools(r.Context(), server)
if err != nil {
h.writeRPCError(w, req.ID, -32000, err.Error())
return
Expand Down Expand Up @@ -950,16 +946,7 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server
h.writeRPCResult(w, req.ID, result)
return
}
callServer := server
if callServer == "" {
resolved, err := h.svc.ResolveToolServer(r.Context(), params.Name)
if err != nil {
h.writeRPCError(w, req.ID, -32000, err.Error())
return
}
callServer = resolved
}
toolReq := invocation.CreateInvocationRequest{Server: callServer, Tool: params.Name, Input: params.Arguments}
toolReq := invocation.CreateInvocationRequest{Server: server, Tool: params.Name, Input: params.Arguments}
if requestID != "" {
toolReq.RequestID = stringPtr(requestID)
}
Expand All @@ -980,7 +967,7 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server
if len(resp.Error) > 0 {
result := normalizeToolCallResult(resp.Error, true)
if resp.Status == invocation.StatusDenied {
result = h.appendRulesContextToToolResult(r.Context(), result, callServer, params.Name)
result = h.appendRulesContextToToolResult(r.Context(), result, server, params.Name)
}
h.writeRPCResult(w, req.ID, result)
return
Expand Down Expand Up @@ -1221,9 +1208,7 @@ type atryumToolPolicy struct {

// annotateToolsWithPolicy decorates each tool with its effective approval
// disposition for the current agent so the model sees the policy at the moment
// it picks a tool. Annotation requires both rulesRepo and a concrete server;
// in aggregate mode (server == "") we cannot reliably attribute tools to a
// server, so we return the tools unchanged.
// it picks a tool. Annotation requires both rulesRepo and a concrete server.
func (h *Handler) annotateToolsWithPolicy(ctx context.Context, server string, tools []mcp.Tool) []any {
out := make([]any, len(tools))
if h.rulesRepo == nil || strings.TrimSpace(server) == "" {
Expand Down
34 changes: 28 additions & 6 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,6 @@ func (s *stubService) Invoke(ctx context.Context, req invocation.CreateInvocatio
func (s *stubService) ListTools(context.Context, string) ([]mcp.Tool, error) {
return s.tools, s.listErr
}
func (s *stubService) ListAllTools(context.Context) ([]mcp.Tool, error) {
return s.tools, s.listErr
}
func (s *stubService) ResolveToolServer(_ context.Context, _ string) (string, error) {
return s.upstream.Name, nil
}
func (s *stubService) Get(_ context.Context, _ string) (invocation.InvocationResponse, error) {
return s.invoke, s.getErr
}
Expand Down Expand Up @@ -712,6 +706,34 @@ func TestMCPDeleteReturn405(t *testing.T) {
}
}

func TestMCPRootRequiresServer(t *testing.T) {
h := NewHandler(&stubService{tools: []mcp.Tool{{Name: "demo_tool"}}}, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil)
tests := []struct {
name string
method string
path string
body string
}{
{name: "bare path", method: http.MethodPost, path: "/mcp", body: `{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`},
{name: "trailing slash", method: http.MethodPost, path: "/mcp/", body: `{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`},
{name: "root sse", method: http.MethodGet, path: "/mcp/"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, tt.path, strings.NewReader(tt.body))
w := httptest.NewRecorder()
h.Routes().ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("%s %s expected 404, got %d body=%s", tt.method, tt.path, w.Code, w.Body.String())
}
if strings.Contains(w.Body.String(), "demo_tool") {
t.Fatalf("root MCP endpoint should not list tools, got %s", w.Body.String())
}
})
}
}

func TestMCPToolsList(t *testing.T) {
h := NewHandler(&stubService{tools: []mcp.Tool{{Name: "demo_tool"}}}, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil)
req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`))
Expand Down
43 changes: 0 additions & 43 deletions internal/invocation/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1075,49 +1075,6 @@ func (s *Service) ListTools(ctx context.Context, server string) ([]mcp.Tool, err
return tools, nil
}

// ResolveToolServer finds which upstream server provides the named tool.
// Used in aggregate mode (no server in URL) to route tools/call correctly.
func (s *Service) ResolveToolServer(ctx context.Context, toolName string) (string, error) {
upstreams, err := s.resolver.ListAll(ctx)
if err != nil {
return "", err
}
for _, upstream := range upstreams {
tctx, cancel := context.WithTimeout(ctx, s.defaultTimeout)
tools, err := s.client.ListTools(tctx, upstream)
cancel()
if err != nil {
continue
}
for _, t := range tools {
if t.Name == toolName {
return upstream.Name, nil
}
}
}
return "", fmt.Errorf("no server found for tool %q", toolName)
}

// ListAllTools aggregates tools from every enabled upstream. Used when the MCP
// client connects to the root /mcp endpoint without specifying a server name.
func (s *Service) ListAllTools(ctx context.Context) ([]mcp.Tool, error) {
upstreams, err := s.resolver.ListAll(ctx)
if err != nil {
return nil, err
}
var all []mcp.Tool
for _, upstream := range upstreams {
tctx, cancel := context.WithTimeout(ctx, s.defaultTimeout)
tools, err := s.client.ListTools(tctx, upstream)
cancel()
if err != nil {
continue // skip unreachable servers rather than failing the whole list
}
all = append(all, tools...)
}
return all, nil
}

func (s *Service) Get(ctx context.Context, id string) (InvocationResponse, error) {
inv, err := s.invocations.Get(ctx, id)
if err != nil {
Expand Down
16 changes: 5 additions & 11 deletions scripts/fake_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@
# Pretend to be a specific harness:
python fake_agent.py mcp --client-name cursor --client-version 0.45.7 --list-tools

# Aggregate /mcp/ endpoint (every server's tools merged):
python fake_agent.py mcp '' --list-tools

Config (env or flags):
ATRYUM_URL base url, default http://localhost:8080
ATRYUM_MCP_SERVER default MCP server name, default "calc-mcp"
Expand Down Expand Up @@ -415,7 +412,7 @@ def run_mcp(
if bearer:
headers["Authorization"] = f"Bearer {bearer}"

print(f"mcp: server={server or '(aggregate)'} client={client_name}/{client_version}")
print(f"mcp: server={server} client={client_name}/{client_version}")

# 1. initialize
init = mcp_call(
Expand Down Expand Up @@ -519,10 +516,7 @@ def main(argv: list[str] | None = None) -> int:
"server",
nargs="?",
default=DEFAULT_MCP_SERVER,
help=(
"MCP server name to talk to "
f"(default {DEFAULT_MCP_SERVER!r}; pass empty string '' for aggregate /mcp/)"
),
help=f"MCP server name to talk to (default {DEFAULT_MCP_SERVER!r})",
)
pm.add_argument("--tool", default=None, help="tool name for tools/call")
pm.add_argument(
Expand Down Expand Up @@ -570,11 +564,11 @@ def main(argv: list[str] | None = None) -> int:
poll_ms=args.poll_ms,
)
elif args.mode == "mcp":
# Empty string explicitly opts into the aggregate /mcp/ route.
server = args.server if args.server != "" else None
if args.server == "":
raise SystemExit("mcp server name is required; use /mcp/{server}")
run_mcp(
base=args.base,
server=server,
server=args.server,
tool=args.tool,
arguments=_parse_json_arg(args.arguments, "arguments"),
list_tools=args.list_tools,
Expand Down
Loading