From 81f7ba318d1b4ac484a0f90280ba95fe5218c32a Mon Sep 17 00:00:00 2001 From: Hunter Haugen Date: Fri, 12 Jun 2026 12:32:53 -0700 Subject: [PATCH] remove root MCP aggregate endpoint --- internal/api/handlers.go | 41 ++++++++++---------------------- internal/api/handlers_test.go | 34 ++++++++++++++++++++++----- internal/invocation/service.go | 43 ---------------------------------- scripts/fake_agent.py | 16 ++++--------- 4 files changed, 46 insertions(+), 88 deletions(-) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 37c968e..d3ca311 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -38,8 +38,6 @@ var webFS embed.FS type service interface { Invoke(ctx context.Context, req invocation.CreateInvocationRequest) (invocation.InvocationResponse, error) ListTools(ctx context.Context, server string) ([]mcp.Tool, error) - ListAllTools(ctx context.Context) ([]mcp.Tool, error) - ResolveToolServer(ctx context.Context, toolName string) (string, error) Get(ctx context.Context, id string) (invocation.InvocationResponse, error) List(ctx context.Context, filter invocation.InvocationListFilter) (invocation.InvocationListResponse, error) ListAgentIDs(ctx context.Context) ([]string, error) @@ -522,6 +520,7 @@ func (h *Handler) Routes() http.Handler { } mcpHandler := auth.MiddlewareWithOptions(h.authValidator, "/.well-known/oauth-protected-resource", auth.MiddlewareOptions{SkipVerify: h.authDebugSkip, DebugLogIdentity: h.debug})(http.HandlerFunc(h.invokeUpstream)) mcpHandler = h.noAuthAgentIDHint(mcpHandler) + mux.HandleFunc("/mcp", h.mcpRootNotFound) mux.Handle("/mcp/", mcpHandler) mux.HandleFunc("/api/v1/invocations", h.invocations) mux.HandleFunc("/api/v1/admin/invocations", h.adminInvocations) @@ -614,6 +613,10 @@ func (h *Handler) root(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/ui/", http.StatusFound) } +func (h *Handler) mcpRootNotFound(w http.ResponseWriter, _ *http.Request) { + writeError(w, http.StatusNotFound, "MCP server name is required; use /mcp/{server}") +} + func (h *Handler) uiIndex(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/ui" { writeError(w, http.StatusNotFound, "not found") @@ -643,6 +646,10 @@ func (h *Handler) spaFileServer() http.Handler { func (h *Handler) invokeUpstream(w http.ResponseWriter, r *http.Request) { server := strings.TrimPrefix(r.URL.Path, "/mcp/") server = strings.Trim(server, "/") + if server == "" { + writeError(w, http.StatusNotFound, "MCP server name is required; use /mcp/{server}") + return + } // GET: open an SSE keepalive stream (Streamable HTTP transport and legacy SSE clients both try GET) if r.Method == http.MethodGet { @@ -658,14 +665,9 @@ func (h *Handler) invokeUpstream(w http.ResponseWriter, r *http.Request) { return } if isJSONRPCRequest(r) { - // server may be "" — handleMCPProxy handles the aggregate (no-server) case h.handleMCPProxy(w, r, server) return } - if server == "" { - writeError(w, http.StatusNotFound, "server not found") - return - } h.handleInvocation(w, r, server) } @@ -916,13 +918,7 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server case "notifications/initialized": w.WriteHeader(http.StatusAccepted) case "tools/list": - var tools []mcp.Tool - var err error - if server == "" { - tools, err = h.svc.ListAllTools(r.Context()) - } else { - tools, err = h.svc.ListTools(r.Context(), server) - } + tools, err := h.svc.ListTools(r.Context(), server) if err != nil { h.writeRPCError(w, req.ID, -32000, err.Error()) return @@ -950,16 +946,7 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server h.writeRPCResult(w, req.ID, result) return } - callServer := server - if callServer == "" { - resolved, err := h.svc.ResolveToolServer(r.Context(), params.Name) - if err != nil { - h.writeRPCError(w, req.ID, -32000, err.Error()) - return - } - callServer = resolved - } - toolReq := invocation.CreateInvocationRequest{Server: callServer, Tool: params.Name, Input: params.Arguments} + toolReq := invocation.CreateInvocationRequest{Server: server, Tool: params.Name, Input: params.Arguments} if requestID != "" { toolReq.RequestID = stringPtr(requestID) } @@ -980,7 +967,7 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server if len(resp.Error) > 0 { result := normalizeToolCallResult(resp.Error, true) if resp.Status == invocation.StatusDenied { - result = h.appendRulesContextToToolResult(r.Context(), result, callServer, params.Name) + result = h.appendRulesContextToToolResult(r.Context(), result, server, params.Name) } h.writeRPCResult(w, req.ID, result) return @@ -1221,9 +1208,7 @@ type atryumToolPolicy struct { // annotateToolsWithPolicy decorates each tool with its effective approval // disposition for the current agent so the model sees the policy at the moment -// it picks a tool. Annotation requires both rulesRepo and a concrete server; -// in aggregate mode (server == "") we cannot reliably attribute tools to a -// server, so we return the tools unchanged. +// it picks a tool. Annotation requires both rulesRepo and a concrete server. func (h *Handler) annotateToolsWithPolicy(ctx context.Context, server string, tools []mcp.Tool) []any { out := make([]any, len(tools)) if h.rulesRepo == nil || strings.TrimSpace(server) == "" { diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index 2519704..f86f401 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -45,12 +45,6 @@ func (s *stubService) Invoke(ctx context.Context, req invocation.CreateInvocatio func (s *stubService) ListTools(context.Context, string) ([]mcp.Tool, error) { return s.tools, s.listErr } -func (s *stubService) ListAllTools(context.Context) ([]mcp.Tool, error) { - return s.tools, s.listErr -} -func (s *stubService) ResolveToolServer(_ context.Context, _ string) (string, error) { - return s.upstream.Name, nil -} func (s *stubService) Get(_ context.Context, _ string) (invocation.InvocationResponse, error) { return s.invoke, s.getErr } @@ -712,6 +706,34 @@ func TestMCPDeleteReturn405(t *testing.T) { } } +func TestMCPRootRequiresServer(t *testing.T) { + h := NewHandler(&stubService{tools: []mcp.Tool{{Name: "demo_tool"}}}, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) + tests := []struct { + name string + method string + path string + body string + }{ + {name: "bare path", method: http.MethodPost, path: "/mcp", body: `{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`}, + {name: "trailing slash", method: http.MethodPost, path: "/mcp/", body: `{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`}, + {name: "root sse", method: http.MethodGet, path: "/mcp/"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, strings.NewReader(tt.body)) + w := httptest.NewRecorder() + h.Routes().ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("%s %s expected 404, got %d body=%s", tt.method, tt.path, w.Code, w.Body.String()) + } + if strings.Contains(w.Body.String(), "demo_tool") { + t.Fatalf("root MCP endpoint should not list tools, got %s", w.Body.String()) + } + }) + } +} + func TestMCPToolsList(t *testing.T) { h := NewHandler(&stubService{tools: []mcp.Tool{{Name: "demo_tool"}}}, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)) diff --git a/internal/invocation/service.go b/internal/invocation/service.go index 1d684e3..c2335a8 100644 --- a/internal/invocation/service.go +++ b/internal/invocation/service.go @@ -1075,49 +1075,6 @@ func (s *Service) ListTools(ctx context.Context, server string) ([]mcp.Tool, err return tools, nil } -// ResolveToolServer finds which upstream server provides the named tool. -// Used in aggregate mode (no server in URL) to route tools/call correctly. -func (s *Service) ResolveToolServer(ctx context.Context, toolName string) (string, error) { - upstreams, err := s.resolver.ListAll(ctx) - if err != nil { - return "", err - } - for _, upstream := range upstreams { - tctx, cancel := context.WithTimeout(ctx, s.defaultTimeout) - tools, err := s.client.ListTools(tctx, upstream) - cancel() - if err != nil { - continue - } - for _, t := range tools { - if t.Name == toolName { - return upstream.Name, nil - } - } - } - return "", fmt.Errorf("no server found for tool %q", toolName) -} - -// ListAllTools aggregates tools from every enabled upstream. Used when the MCP -// client connects to the root /mcp endpoint without specifying a server name. -func (s *Service) ListAllTools(ctx context.Context) ([]mcp.Tool, error) { - upstreams, err := s.resolver.ListAll(ctx) - if err != nil { - return nil, err - } - var all []mcp.Tool - for _, upstream := range upstreams { - tctx, cancel := context.WithTimeout(ctx, s.defaultTimeout) - tools, err := s.client.ListTools(tctx, upstream) - cancel() - if err != nil { - continue // skip unreachable servers rather than failing the whole list - } - all = append(all, tools...) - } - return all, nil -} - func (s *Service) Get(ctx context.Context, id string) (InvocationResponse, error) { inv, err := s.invocations.Get(ctx, id) if err != nil { diff --git a/scripts/fake_agent.py b/scripts/fake_agent.py index a7e1028..7daa7c8 100644 --- a/scripts/fake_agent.py +++ b/scripts/fake_agent.py @@ -40,9 +40,6 @@ # Pretend to be a specific harness: python fake_agent.py mcp --client-name cursor --client-version 0.45.7 --list-tools - # Aggregate /mcp/ endpoint (every server's tools merged): - python fake_agent.py mcp '' --list-tools - Config (env or flags): ATRYUM_URL base url, default http://localhost:8080 ATRYUM_MCP_SERVER default MCP server name, default "calc-mcp" @@ -415,7 +412,7 @@ def run_mcp( if bearer: headers["Authorization"] = f"Bearer {bearer}" - print(f"mcp: server={server or '(aggregate)'} client={client_name}/{client_version}") + print(f"mcp: server={server} client={client_name}/{client_version}") # 1. initialize init = mcp_call( @@ -519,10 +516,7 @@ def main(argv: list[str] | None = None) -> int: "server", nargs="?", default=DEFAULT_MCP_SERVER, - help=( - "MCP server name to talk to " - f"(default {DEFAULT_MCP_SERVER!r}; pass empty string '' for aggregate /mcp/)" - ), + help=f"MCP server name to talk to (default {DEFAULT_MCP_SERVER!r})", ) pm.add_argument("--tool", default=None, help="tool name for tools/call") pm.add_argument( @@ -570,11 +564,11 @@ def main(argv: list[str] | None = None) -> int: poll_ms=args.poll_ms, ) elif args.mode == "mcp": - # Empty string explicitly opts into the aggregate /mcp/ route. - server = args.server if args.server != "" else None + if args.server == "": + raise SystemExit("mcp server name is required; use /mcp/{server}") run_mcp( base=args.base, - server=server, + server=args.server, tool=args.tool, arguments=_parse_json_arg(args.arguments, "arguments"), list_tools=args.list_tools,