From f8e7f2380d3cb4ab60508a7eb6b6becd68aae353 Mon Sep 17 00:00:00 2001 From: Hunter Haugen Date: Thu, 11 Jun 2026 12:42:12 -0700 Subject: [PATCH] Refactor MCP SSE response boundaries --- internal/api/handlers.go | 64 +++++++- internal/api/handlers_test.go | 122 +++++++++++++++ internal/mcp/client.go | 183 ++++++++++++++++++++--- internal/mcp/client_test.go | 273 ++++++++++++++++++++++++++++++++++ 4 files changed, 614 insertions(+), 28 deletions(-) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 0ce7ee1..cb447ec 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -11,6 +11,7 @@ import ( "io" "io/fs" "log" + "mime" "net/http" "os" "sort" @@ -988,17 +989,26 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server if forwarded.ProtocolVersion != "" { setMCPProtocolVersionHeader(w, forwarded.ProtocolVersion) } - if forwarded.ContentType != "" { - w.Header().Set("Content-Type", forwarded.ContentType) - } else { - w.Header().Set("Content-Type", "application/json") + if req.IsNotification() { + w.WriteHeader(http.StatusAccepted) + return } + body, contentType, err := responseBodyForDownstream(r, req.ID, forwarded) + if err != nil { + status := forwarded.StatusCode + if status == 0 { + status = http.StatusOK + } + h.writeRPCErrorStatus(w, status, req.ID, -32000, err.Error()) + return + } + w.Header().Set("Content-Type", contentType) status := forwarded.StatusCode if status == 0 { status = http.StatusOK } w.WriteHeader(status) - _, _ = w.Write(forwarded.Body) + _, _ = w.Write(body) } } @@ -3319,6 +3329,44 @@ func (h *Handler) forwardProxyEnvelope(ctx context.Context, server string, envel return result, true } +func responseBodyForDownstream(r *http.Request, id json.RawMessage, forwarded mcp.ForwardResult) ([]byte, string, error) { + contentType := forwarded.ContentType + if contentType == "" { + contentType = "application/json" + } + if !strings.Contains(strings.ToLower(contentType), "text/event-stream") { + return forwarded.Body, contentType, nil + } + if acceptsEventStream(r) { + return forwarded.Body, contentType, nil + } + body, err := mcp.DecodeJSONRPCPayload(forwarded, id) + if err != nil { + return nil, "", err + } + return body, "application/json", nil +} + +func acceptsEventStream(r *http.Request) bool { + for _, value := range strings.Split(r.Header.Get("Accept"), ",") { + mediaType, params, err := mime.ParseMediaType(strings.TrimSpace(value)) + if err != nil { + continue + } + if !strings.EqualFold(mediaType, "text/event-stream") { + continue + } + if q, ok := params["q"]; ok { + weight, err := strconv.ParseFloat(q, 64) + if err != nil || weight <= 0 { + continue + } + } + return true + } + return false +} + func negotiateProtocolVersion(method, header string, params json.RawMessage) string { var payload struct { ProtocolVersion string `json:"protocolVersion"` @@ -3369,8 +3417,12 @@ func setMCPProtocolVersionHeader(w http.ResponseWriter, version string) { } func (h *Handler) writeRPCError(w http.ResponseWriter, id json.RawMessage, code int, message string) { + h.writeRPCErrorStatus(w, http.StatusOK, id, code, message) +} + +func (h *Handler) writeRPCErrorStatus(w http.ResponseWriter, status int, id json.RawMessage, code int, message string) { errBody, _ := json.Marshal(map[string]any{"code": code, "message": message}) - writeJSON(w, http.StatusOK, jsonRPCResponse{JSONRPC: "2.0", ID: id, Error: errBody}) + writeJSON(w, status, jsonRPCResponse{JSONRPC: "2.0", ID: id, Error: errBody}) } func normalizeToolCallResult(raw json.RawMessage, isError bool) any { diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index 8db774f..d521452 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -512,6 +512,128 @@ func TestMCPPingPassThrough(t *testing.T) { } } +func TestMCPPingCollapsesUpstreamSSEWhenDownstreamDoesNotAcceptEventStream(t *testing.T) { + sseBody := []byte("event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/progress\",\"params\":{\"progress\":0.5}}\n\n" + + "event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n") + svc := &stubService{ + upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP}, + forward: mcp.ForwardResult{StatusCode: http.StatusOK, Body: sseBody, ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"}, + } + h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) + req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}`)) + req.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if ct := w.Header().Get("Content-Type"); !strings.Contains(ct, "application/json") { + t.Fatalf("expected application/json, got %q", ct) + } + if strings.Contains(w.Body.String(), "event: message") { + t.Fatalf("expected collapsed JSON body, got %q", w.Body.String()) + } + if !strings.Contains(w.Body.String(), `"ok":true`) { + t.Fatalf("expected forwarded JSON-RPC response, got %s", w.Body.String()) + } +} + +func TestMCPPingPreservesUpstreamSSEWhenDownstreamAcceptsEventStream(t *testing.T) { + sseBody := []byte("event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n") + svc := &stubService{ + upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP}, + forward: mcp.ForwardResult{StatusCode: http.StatusOK, Body: sseBody, ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"}, + } + h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) + req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}`)) + req.Header.Set("Accept", "text/event-stream") + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if ct := w.Header().Get("Content-Type"); !strings.Contains(ct, "text/event-stream") { + t.Fatalf("expected text/event-stream, got %q", ct) + } + if !strings.Contains(w.Body.String(), "event: message") { + t.Fatalf("expected raw SSE body, got %q", w.Body.String()) + } +} + +func TestMCPPingCollapsesUpstreamSSEWhenEventStreamHasZeroQuality(t *testing.T) { + sseBody := []byte("event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n") + svc := &stubService{ + upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP}, + forward: mcp.ForwardResult{StatusCode: http.StatusOK, Body: sseBody, ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"}, + } + h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) + req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}`)) + req.Header.Set("Accept", "application/json, text/event-stream;q=0") + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if strings.Contains(w.Body.String(), "event: message") { + t.Fatalf("expected collapsed JSON body, got %q", w.Body.String()) + } + if !strings.Contains(w.Body.String(), `"ok":true`) { + t.Fatalf("expected forwarded JSON-RPC response, got %s", w.Body.String()) + } +} + +func TestMCPPingSSEDecodeErrorPreservesUpstreamStatus(t *testing.T) { + svc := &stubService{ + upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP}, + forward: mcp.ForwardResult{StatusCode: http.StatusBadGateway, Body: []byte(": keepalive\n\n"), ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"}, + } + h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) + req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}`)) + req.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusBadGateway { + t.Fatalf("expected 502, got %d", w.Code) + } + if !strings.Contains(w.Body.String(), "no JSON-RPC response in SSE stream") { + t.Fatalf("expected decode error body, got %s", w.Body.String()) + } +} + +func TestMCPForwardedNotificationReturnsAcceptedWithoutSSEBody(t *testing.T) { + sseBody := []byte("event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":5,\"result\":{\"unrelated\":true}}\n\n") + svc := &stubService{ + upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP}, + forward: mcp.ForwardResult{StatusCode: http.StatusOK, Body: sseBody, ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"}, + } + h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) + req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","method":"notifications/custom","params":{"x":1}}`)) + req.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Fatalf("expected 202, got %d", w.Code) + } + if body := strings.TrimSpace(w.Body.String()); body != "" { + t.Fatalf("expected empty notification response body, got %q", body) + } +} + func TestMCPUnknownNotificationPassThroughFallbackAccepted(t *testing.T) { svc := &stubService{fwdErr: context.Canceled} h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 908542b..a8ad2a3 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -14,6 +14,7 @@ import ( "os" "os/exec" "sort" + "strconv" "strings" "sync" "sync/atomic" @@ -532,8 +533,8 @@ func (c *Client) invokeHTTP(ctx context.Context, upstream Upstream, tool string, if err != nil { return InvokeResult{}, err } - var rpcResp rpcResponse - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + rpcResp, err := decodeRPCResponse(result, json.RawMessage([]byte("1"))) + if err != nil { return InvokeResult{}, err } if len(rpcResp.Error) > 0 && string(rpcResp.Error) != "null" { @@ -545,8 +546,8 @@ func (c *Client) invokeHTTP(ctx context.Context, upstream Upstream, tool string, if err != nil { return InvokeResult{}, err } - rpcResp = rpcResponse{} - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + rpcResp, err = decodeRPCResponse(result, json.RawMessage([]byte("1"))) + if err != nil { return InvokeResult{}, err } } @@ -575,8 +576,8 @@ func (c *Client) listToolsHTTP(ctx context.Context, upstream Upstream) ([]Tool, if err != nil { return nil, err } - var rpcResp rpcResponse - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + rpcResp, err := decodeRPCResponse(result, json.RawMessage([]byte("1"))) + if err != nil { return nil, err } if len(rpcResp.Error) > 0 && string(rpcResp.Error) != "null" { @@ -588,8 +589,8 @@ func (c *Client) listToolsHTTP(ctx context.Context, upstream Upstream) ([]Tool, if err != nil { return nil, err } - rpcResp = rpcResponse{} - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + rpcResp, err = decodeRPCResponse(result, json.RawMessage([]byte("1"))) + if err != nil { return nil, err } } @@ -664,13 +665,6 @@ func (c *Client) doHTTPEnvelope(ctx context.Context, upstream Upstream, body []b bodyBytes, _ := io.ReadAll(resp.Body) return ForwardResult{StatusCode: resp.StatusCode, Body: bodyBytes, ContentType: contentType, ProtocolVersion: resp.Header.Get("MCP-Protocol-Version"), SessionExpired: true, SessionID: sessionID}, nil } - if strings.Contains(contentType, "text/event-stream") { - data, sseErr := extractFirstSSEData(resp.Body) - if sseErr != nil { - return ForwardResult{}, sseErr - } - return ForwardResult{StatusCode: resp.StatusCode, Body: data, ContentType: "application/json", ProtocolVersion: resp.Header.Get("MCP-Protocol-Version"), SessionID: sessionID}, nil - } respBody := new(bytes.Buffer) _, err = respBody.ReadFrom(resp.Body) if err != nil { @@ -832,11 +826,16 @@ func (c *Client) initializeHTTPSession(ctx context.Context, upstream Upstream, p c.debugf("connection test http response server=%s status=%d content_type=%q protocol=%q body=%s", upstream.Name, result.StatusCode, result.ContentType, result.ProtocolVersion, truncateForLog(result.Body, 600)) if result.StatusCode >= http.StatusBadRequest { c.clearSession(upstream.Name) - return false, fmt.Errorf("upstream initialize using MCP %s failed: http %d: %s", protocolVersion, result.StatusCode, extractErrorDetail(result.Body)) + return false, fmt.Errorf("upstream initialize using MCP %s failed: http %d: %s", protocolVersion, result.StatusCode, extractForwardResultErrorDetail(result, json.RawMessage([]byte("1")))) } - if len(bytes.TrimSpace(result.Body)) > 0 { + resultBody, err := DecodeJSONRPCPayload(result, json.RawMessage([]byte("1"))) + if err != nil { + c.clearSession(upstream.Name) + return false, fmt.Errorf("upstream initialize using MCP %s returned invalid JSON-RPC: %w", protocolVersion, err) + } + if len(bytes.TrimSpace(resultBody)) > 0 { var rpcResp rpcResponse - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + if err := json.Unmarshal(resultBody, &rpcResp); err != nil { c.clearSession(upstream.Name) return false, fmt.Errorf("upstream initialize using MCP %s returned invalid JSON-RPC: %w", protocolVersion, err) } @@ -858,6 +857,36 @@ func (c *Client) initializeHTTPSession(ctx context.Context, upstream Upstream, p return hadSession, nil } +func decodeRPCResponse(result ForwardResult, expectedID json.RawMessage) (rpcResponse, error) { + body, err := DecodeJSONRPCPayload(result, expectedID) + if err != nil { + return rpcResponse{}, err + } + var rpcResp rpcResponse + if err := json.Unmarshal(body, &rpcResp); err != nil { + return rpcResponse{}, err + } + return rpcResp, nil +} + +// DecodeJSONRPCPayload returns the JSON-RPC response payload from a raw +// upstream transport result. Plain JSON bodies are already one payload; SSE +// bodies are scanned for the response event matching expectedID, skipping +// notifications and joining multi-line data fields according to the SSE spec. +func DecodeJSONRPCPayload(result ForwardResult, expectedID json.RawMessage) ([]byte, error) { + if strings.Contains(strings.ToLower(result.ContentType), "text/event-stream") { + return extractSSEJSONRPCResponse(bytes.NewReader(result.Body), expectedID) + } + return result.Body, nil +} + +func extractForwardResultErrorDetail(result ForwardResult, expectedID json.RawMessage) string { + if body, err := DecodeJSONRPCPayload(result, expectedID); err == nil && len(bytes.TrimSpace(body)) > 0 { + return extractErrorDetail(body) + } + return extractErrorDetail(result.Body) +} + func (c *Client) invokeStdio(ctx context.Context, upstream Upstream, tool string, input map[string]any) (InvokeResult, error) { if upstream.Command == "" { return InvokeResult{}, fmt.Errorf("stdio upstream %q missing command", upstream.Name) @@ -1177,19 +1206,129 @@ func (c *Client) testStdio(ctx context.Context, upstream Upstream) ConnectionTes return ConnectionTestResult{Ok: true, Message: "stdio initialize ok", ConnectionStatus: ConnectionStatusReady, AuthStatus: AuthStatusReady, ReauthNeeded: false, LastCheckOK: true} } -func extractFirstSSEData(r io.Reader) ([]byte, error) { +func extractSSEJSONRPCResponse(r io.Reader, expectedID json.RawMessage) ([]byte, error) { scanner := bufio.NewScanner(r) scanner.Buffer(make([]byte, 1024*1024), 4*1024*1024) + var dataLines []string + flush := func() ([]byte, bool) { + if len(dataLines) == 0 { + return nil, false + } + payload := []byte(strings.Join(dataLines, "\n")) + dataLines = nil + match := classifyJSONRPCResponsePayload(payload, expectedID) + if match == jsonRPCResponseIDMatch || match == jsonRPCResponseNullIDError { + return payload, true + } + return nil, false + } for scanner.Scan() { line := scanner.Text() - if strings.HasPrefix(line, "data:") { - return []byte(strings.TrimSpace(strings.TrimPrefix(line, "data:"))), nil + if line == "" { + if payload, ok := flush(); ok { + return payload, nil + } + continue + } + if strings.HasPrefix(line, ":") { + continue + } + field, value, ok := strings.Cut(line, ":") + if !ok { + field = line + value = "" + } else if strings.HasPrefix(value, " ") { + value = strings.TrimPrefix(value, " ") + } + if field == "data" { + dataLines = append(dataLines, value) } } if err := scanner.Err(); err != nil { return nil, err } - return nil, fmt.Errorf("no data in SSE stream") + if payload, ok := flush(); ok { + return payload, nil + } + return nil, fmt.Errorf("no JSON-RPC response in SSE stream") +} + +type jsonRPCResponseMatch int + +const ( + jsonRPCResponseNoMatch jsonRPCResponseMatch = iota + jsonRPCResponseIDMatch + jsonRPCResponseNullIDError +) + +func classifyJSONRPCResponsePayload(payload []byte, expectedID json.RawMessage) jsonRPCResponseMatch { + var message map[string]json.RawMessage + if err := json.Unmarshal(payload, &message); err != nil { + return jsonRPCResponseNoMatch + } + id, hasID := message["id"] + if !hasID { + return jsonRPCResponseNoMatch + } + _, hasResult := message["result"] + _, hasError := message["error"] + if !hasResult && !hasError { + return jsonRPCResponseNoMatch + } + if hasError && jsonRawIsNull(id) { + return jsonRPCResponseNullIDError + } + if len(bytes.TrimSpace(expectedID)) == 0 { + return jsonRPCResponseNoMatch + } + if jsonRPCIDsMatch(id, expectedID) { + return jsonRPCResponseIDMatch + } + return jsonRPCResponseNoMatch +} + +func jsonRPCIDsMatch(a, b json.RawMessage) bool { + if jsonRawEqual(a, b) { + return true + } + av, aok := jsonIDComparable(a) + bv, bok := jsonIDComparable(b) + return aok && bok && av == bv +} + +func jsonRawIsNull(raw json.RawMessage) bool { + var compact bytes.Buffer + if err := json.Compact(&compact, raw); err != nil { + return bytes.Equal(bytes.TrimSpace(raw), []byte("null")) + } + return bytes.Equal(compact.Bytes(), []byte("null")) +} + +func jsonIDComparable(raw json.RawMessage) (string, bool) { + var value any + if err := json.Unmarshal(raw, &value); err != nil { + return "", false + } + switch v := value.(type) { + case string: + return v, true + case float64: + return strconv.FormatFloat(v, 'f', -1, 64), true + default: + return "", false + } +} + +func jsonRawEqual(a, b json.RawMessage) bool { + var compactA bytes.Buffer + if err := json.Compact(&compactA, a); err != nil { + return bytes.Equal(bytes.TrimSpace(a), bytes.TrimSpace(b)) + } + var compactB bytes.Buffer + if err := json.Compact(&compactB, b); err != nil { + return bytes.Equal(bytes.TrimSpace(a), bytes.TrimSpace(b)) + } + return bytes.Equal(compactA.Bytes(), compactB.Bytes()) } func extractErrorDetail(body []byte) string { diff --git a/internal/mcp/client_test.go b/internal/mcp/client_test.go index 997c0c0..5093992 100644 --- a/internal/mcp/client_test.go +++ b/internal/mcp/client_test.go @@ -132,6 +132,254 @@ func TestListToolsRetriesInitializeWithCompatibleProtocol(t *testing.T) { } } +func TestListToolsDecodesSSEResponseAfterMissingSessionReinitialize(t *testing.T) { + var toolsListCount int + var sessions []string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Envelope + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + switch req.Method { + case "initialize": + sessionID := "sid-1" + if len(sessions) > 0 { + sessionID = "sid-2" + } + sessions = append(sessions, sessionID) + w.Header().Set("Mcp-Session-Id", sessionID) + writeTestRPC(w, req.ID, map[string]any{"protocolVersion": r.Header.Get("MCP-Protocol-Version"), "capabilities": map[string]any{}}, nil) + case "notifications/initialized": + w.WriteHeader(http.StatusAccepted) + case "tools/list": + toolsListCount++ + if toolsListCount == 1 { + writeTestRPC(w, req.ID, nil, map[string]any{"code": -32000, "message": "No session ID provided for non-initialization request"}) + return + } + if got := r.Header.Get("Mcp-Session-Id"); got != "sid-2" { + t.Fatalf("retry tools/list used session %q, want sid-2", got) + } + writeTestRPCSSE(w, req.ID, map[string]any{"tools": []map[string]any{{"name": "stories.search"}}}, nil) + default: + t.Fatalf("unexpected method %q", req.Method) + } + })) + defer server.Close() + + client := NewHTTPClient() + client.httpClient = server.Client() + tools, err := client.ListTools(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL}) + if err != nil { + t.Fatalf("ListTools returned error: %v", err) + } + if len(tools) != 1 || tools[0].Name != "stories.search" { + t.Fatalf("unexpected tools: %#v", tools) + } + if toolsListCount != 2 { + t.Fatalf("tools/list count = %d, want 2", toolsListCount) + } + if len(sessions) != 2 { + t.Fatalf("initialize sessions = %#v, want two sessions", sessions) + } +} + +func TestInitializeDecodesSSEResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Envelope + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + switch req.Method { + case "initialize": + w.Header().Set("Mcp-Session-Id", "sid-sse") + writeTestRPCSSE(w, req.ID, map[string]any{"protocolVersion": r.Header.Get("MCP-Protocol-Version"), "capabilities": map[string]any{}}, nil) + case "notifications/initialized": + if got := r.Header.Get("Mcp-Session-Id"); got != "sid-sse" { + t.Fatalf("initialized notification missing session id: %q", got) + } + w.WriteHeader(http.StatusAccepted) + case "tools/list": + if got := r.Header.Get("Mcp-Session-Id"); got != "sid-sse" { + t.Fatalf("tools/list missing session id: %q", got) + } + writeTestRPC(w, req.ID, map[string]any{"tools": []map[string]any{{"name": "stories.search"}}}, nil) + default: + t.Fatalf("unexpected method %q", req.Method) + } + })) + defer server.Close() + + client := NewHTTPClient() + client.httpClient = server.Client() + tools, err := client.ListTools(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL}) + if err != nil { + t.Fatalf("ListTools returned error: %v", err) + } + if len(tools) != 1 || tools[0].Name != "stories.search" { + t.Fatalf("unexpected tools: %#v", tools) + } +} + +func TestInitializeSSEHTTPErrorReportsRPCMessage(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Envelope + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + if req.Method != "initialize" { + t.Fatalf("unexpected method %q", req.Method) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusUnauthorized) + writeTestRPCSSE(w, req.ID, nil, map[string]any{"code": -32000, "message": "unauthorized"}) + })) + defer server.Close() + + client := NewHTTPClient() + client.httpClient = server.Client() + result := client.TestConnection(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL, Enabled: true}) + if result.Ok { + t.Fatalf("expected failed connection test") + } + if !strings.Contains(result.Message, "unauthorized") { + t.Fatalf("expected clean upstream error message, got %q", result.Message) + } + if strings.Contains(result.Message, "data:") { + t.Fatalf("expected SSE framing to be hidden, got %q", result.Message) + } +} + +func TestInitializeSSEHTTPErrorReportsNullIDRPCMessage(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Envelope + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + if req.Method != "initialize" { + t.Fatalf("unexpected method %q", req.Method) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusBadRequest) + writeTestRPCSSE(w, json.RawMessage("null"), nil, map[string]any{"code": -32700, "message": "parse error"}) + })) + defer server.Close() + + client := NewHTTPClient() + client.httpClient = server.Client() + result := client.TestConnection(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL, Enabled: true}) + if result.Ok { + t.Fatalf("expected failed connection test") + } + if !strings.Contains(result.Message, "parse error") { + t.Fatalf("expected clean upstream error message, got %q", result.Message) + } + if strings.Contains(result.Message, "data:") { + t.Fatalf("expected SSE framing to be hidden, got %q", result.Message) + } +} + +func TestListToolsAcceptsStringIDInSSEResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Envelope + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + switch req.Method { + case "initialize": + w.Header().Set("Mcp-Session-Id", "sid-string-id") + writeTestRPC(w, req.ID, map[string]any{"protocolVersion": r.Header.Get("MCP-Protocol-Version"), "capabilities": map[string]any{}}, nil) + case "notifications/initialized": + w.WriteHeader(http.StatusAccepted) + case "tools/list": + writeTestRPCSSE(w, json.RawMessage(`"1"`), map[string]any{"tools": []map[string]any{{"name": "stories.string_id"}}}, nil) + default: + t.Fatalf("unexpected method %q", req.Method) + } + })) + defer server.Close() + + client := NewHTTPClient() + client.httpClient = server.Client() + tools, err := client.ListTools(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL}) + if err != nil { + t.Fatalf("ListTools returned error: %v", err) + } + if len(tools) != 1 || tools[0].Name != "stories.string_id" { + t.Fatalf("unexpected tools: %#v", tools) + } +} + +func TestInvokeSkipsSSENotificationBeforeResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Envelope + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + switch req.Method { + case "initialize": + w.Header().Set("Mcp-Session-Id", "sid-notify") + writeTestRPC(w, req.ID, map[string]any{"protocolVersion": r.Header.Get("MCP-Protocol-Version"), "capabilities": map[string]any{}}, nil) + case "notifications/initialized": + w.WriteHeader(http.StatusAccepted) + case "tools/call": + writeTestSSEEvents(w, + []string{`{"jsonrpc":"2.0","method":"notifications/progress","params":{"progress":0.5}}`}, + []string{`{"jsonrpc":"2.0","id":1,"result":{"content":[{"type":"text","text":"actual result"}]}}`}, + ) + default: + t.Fatalf("unexpected method %q", req.Method) + } + })) + defer server.Close() + + client := NewHTTPClient() + client.httpClient = server.Client() + result, err := client.Invoke(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL}, "stories.get", map[string]any{}, nil) + if err != nil { + t.Fatalf("Invoke returned error: %v", err) + } + if !strings.Contains(string(result.Body), "actual result") { + t.Fatalf("expected final response body, got %s", string(result.Body)) + } +} + +func TestListToolsDecodesMultilineSSEData(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Envelope + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + switch req.Method { + case "initialize": + w.Header().Set("Mcp-Session-Id", "sid-multiline") + writeTestRPC(w, req.ID, map[string]any{"protocolVersion": r.Header.Get("MCP-Protocol-Version"), "capabilities": map[string]any{}}, nil) + case "notifications/initialized": + w.WriteHeader(http.StatusAccepted) + case "tools/list": + writeTestSSEEvents(w, []string{ + `{"jsonrpc":"2.0",`, + `"id":1,`, + `"result":{"tools":[{"name":"stories.multiline"}]}}`, + }) + default: + t.Fatalf("unexpected method %q", req.Method) + } + })) + defer server.Close() + + client := NewHTTPClient() + client.httpClient = server.Client() + tools, err := client.ListTools(context.Background(), Upstream{Name: "shortcut", Mode: UpstreamModeHTTP, BaseURL: server.URL}) + if err != nil { + t.Fatalf("ListTools returned error: %v", err) + } + if len(tools) != 1 || tools[0].Name != "stories.multiline" { + t.Fatalf("unexpected tools: %#v", tools) + } +} + func TestMissingSessionRPCErrorDetection(t *testing.T) { if !isMissingSessionRPCError(json.RawMessage(`{"code":-32000,"message":"No session ID provided for non-initialization request"}`)) { t.Fatal("expected missing session error to be detected") @@ -417,3 +665,28 @@ func writeTestRPC(w http.ResponseWriter, id json.RawMessage, result any, rpcErr panic(err) } } + +func writeTestRPCSSE(w http.ResponseWriter, id json.RawMessage, result any, rpcErr any) { + resp := map[string]any{"jsonrpc": "2.0", "id": id} + if rpcErr != nil { + resp["error"] = rpcErr + } else { + resp["result"] = result + } + payload, err := json.Marshal(resp) + if err != nil { + panic(err) + } + writeTestSSEEvents(w, []string{string(payload)}) +} + +func writeTestSSEEvents(w http.ResponseWriter, events ...[]string) { + w.Header().Set("Content-Type", "text/event-stream") + for _, lines := range events { + _, _ = w.Write([]byte("event: message\n")) + for _, line := range lines { + _, _ = w.Write([]byte("data: " + line + "\n")) + } + _, _ = w.Write([]byte("\n")) + } +}