Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 58 additions & 6 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"io/fs"
"log"
"mime"
"net/http"
"os"
"sort"
Expand Down Expand Up @@ -988,17 +989,26 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server
if forwarded.ProtocolVersion != "" {
setMCPProtocolVersionHeader(w, forwarded.ProtocolVersion)
}
if forwarded.ContentType != "" {
w.Header().Set("Content-Type", forwarded.ContentType)
} else {
w.Header().Set("Content-Type", "application/json")
if req.IsNotification() {
w.WriteHeader(http.StatusAccepted)
return
}
body, contentType, err := responseBodyForDownstream(r, req.ID, forwarded)
if err != nil {
status := forwarded.StatusCode
if status == 0 {
status = http.StatusOK
}
h.writeRPCErrorStatus(w, status, req.ID, -32000, err.Error())
return
}
w.Header().Set("Content-Type", contentType)
status := forwarded.StatusCode
if status == 0 {
status = http.StatusOK
}
w.WriteHeader(status)
_, _ = w.Write(forwarded.Body)
_, _ = w.Write(body)
}
}

Expand Down Expand Up @@ -3319,6 +3329,44 @@ func (h *Handler) forwardProxyEnvelope(ctx context.Context, server string, envel
return result, true
}

func responseBodyForDownstream(r *http.Request, id json.RawMessage, forwarded mcp.ForwardResult) ([]byte, string, error) {
contentType := forwarded.ContentType
if contentType == "" {
contentType = "application/json"
}
if !strings.Contains(strings.ToLower(contentType), "text/event-stream") {
return forwarded.Body, contentType, nil
}
if acceptsEventStream(r) {
return forwarded.Body, contentType, nil
}
body, err := mcp.DecodeJSONRPCPayload(forwarded, id)
if err != nil {
return nil, "", err
}
return body, "application/json", nil
}

func acceptsEventStream(r *http.Request) bool {
for _, value := range strings.Split(r.Header.Get("Accept"), ",") {
mediaType, params, err := mime.ParseMediaType(strings.TrimSpace(value))
if err != nil {
continue
}
if !strings.EqualFold(mediaType, "text/event-stream") {
continue
}
if q, ok := params["q"]; ok {
weight, err := strconv.ParseFloat(q, 64)
if err != nil || weight <= 0 {
continue
}
}
return true
}
return false
}

func negotiateProtocolVersion(method, header string, params json.RawMessage) string {
var payload struct {
ProtocolVersion string `json:"protocolVersion"`
Expand Down Expand Up @@ -3369,8 +3417,12 @@ func setMCPProtocolVersionHeader(w http.ResponseWriter, version string) {
}

func (h *Handler) writeRPCError(w http.ResponseWriter, id json.RawMessage, code int, message string) {
h.writeRPCErrorStatus(w, http.StatusOK, id, code, message)
}

func (h *Handler) writeRPCErrorStatus(w http.ResponseWriter, status int, id json.RawMessage, code int, message string) {
errBody, _ := json.Marshal(map[string]any{"code": code, "message": message})
writeJSON(w, http.StatusOK, jsonRPCResponse{JSONRPC: "2.0", ID: id, Error: errBody})
writeJSON(w, status, jsonRPCResponse{JSONRPC: "2.0", ID: id, Error: errBody})
}

func normalizeToolCallResult(raw json.RawMessage, isError bool) any {
Expand Down
122 changes: 122 additions & 0 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,128 @@ func TestMCPPingPassThrough(t *testing.T) {
}
}

func TestMCPPingCollapsesUpstreamSSEWhenDownstreamDoesNotAcceptEventStream(t *testing.T) {
sseBody := []byte("event: message\n" +
"data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/progress\",\"params\":{\"progress\":0.5}}\n\n" +
"event: message\n" +
"data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n")
svc := &stubService{
upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP},
forward: mcp.ForwardResult{StatusCode: http.StatusOK, Body: sseBody, ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"},
}
h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil)
req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}`))
req.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()

h.Routes().ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if ct := w.Header().Get("Content-Type"); !strings.Contains(ct, "application/json") {
t.Fatalf("expected application/json, got %q", ct)
}
if strings.Contains(w.Body.String(), "event: message") {
t.Fatalf("expected collapsed JSON body, got %q", w.Body.String())
}
if !strings.Contains(w.Body.String(), `"ok":true`) {
t.Fatalf("expected forwarded JSON-RPC response, got %s", w.Body.String())
}
}

func TestMCPPingPreservesUpstreamSSEWhenDownstreamAcceptsEventStream(t *testing.T) {
sseBody := []byte("event: message\n" +
"data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n")
svc := &stubService{
upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP},
forward: mcp.ForwardResult{StatusCode: http.StatusOK, Body: sseBody, ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"},
}
h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil)
req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}`))
req.Header.Set("Accept", "text/event-stream")
w := httptest.NewRecorder()

h.Routes().ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if ct := w.Header().Get("Content-Type"); !strings.Contains(ct, "text/event-stream") {
t.Fatalf("expected text/event-stream, got %q", ct)
}
if !strings.Contains(w.Body.String(), "event: message") {
t.Fatalf("expected raw SSE body, got %q", w.Body.String())
}
}

func TestMCPPingCollapsesUpstreamSSEWhenEventStreamHasZeroQuality(t *testing.T) {
sseBody := []byte("event: message\n" +
"data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n")
svc := &stubService{
upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP},
forward: mcp.ForwardResult{StatusCode: http.StatusOK, Body: sseBody, ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"},
}
h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil)
req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}`))
req.Header.Set("Accept", "application/json, text/event-stream;q=0")
w := httptest.NewRecorder()

h.Routes().ServeHTTP(w, req)

if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
if strings.Contains(w.Body.String(), "event: message") {
t.Fatalf("expected collapsed JSON body, got %q", w.Body.String())
}
if !strings.Contains(w.Body.String(), `"ok":true`) {
t.Fatalf("expected forwarded JSON-RPC response, got %s", w.Body.String())
}
}

func TestMCPPingSSEDecodeErrorPreservesUpstreamStatus(t *testing.T) {
svc := &stubService{
upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP},
forward: mcp.ForwardResult{StatusCode: http.StatusBadGateway, Body: []byte(": keepalive\n\n"), ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"},
}
h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil)
req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}`))
req.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()

h.Routes().ServeHTTP(w, req)

if w.Code != http.StatusBadGateway {
t.Fatalf("expected 502, got %d", w.Code)
}
if !strings.Contains(w.Body.String(), "no JSON-RPC response in SSE stream") {
t.Fatalf("expected decode error body, got %s", w.Body.String())
}
}

func TestMCPForwardedNotificationReturnsAcceptedWithoutSSEBody(t *testing.T) {
sseBody := []byte("event: message\n" +
"data: {\"jsonrpc\":\"2.0\",\"id\":5,\"result\":{\"unrelated\":true}}\n\n")
svc := &stubService{
upstream: mcp.Upstream{Name: "demo", Mode: mcp.UpstreamModeHTTP},
forward: mcp.ForwardResult{StatusCode: http.StatusOK, Body: sseBody, ContentType: "text/event-stream", ProtocolVersion: "2025-11-25"},
}
h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil)
req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","method":"notifications/custom","params":{"x":1}}`))
req.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()

h.Routes().ServeHTTP(w, req)

if w.Code != http.StatusAccepted {
t.Fatalf("expected 202, got %d", w.Code)
}
if body := strings.TrimSpace(w.Body.String()); body != "" {
t.Fatalf("expected empty notification response body, got %q", body)
}
}

func TestMCPUnknownNotificationPassThroughFallbackAccepted(t *testing.T) {
svc := &stubService{fwdErr: context.Canceled}
h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil)
Expand Down
Loading
Loading