diff --git a/README.md b/README.md index 0c1dc49..023e548 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,28 @@ Atryum mediates three kinds of tool calls: These paths converge on a single service so rules, audit, and the UI work identically regardless of how the call arrived. +MCP task support is available for clients that opt into async tool calls: + +- `tools/list` annotates tool metadata with `execution.taskSupport: "optional"`. +- task-augmented `tools/call` requests return an MCP task handle immediately instead of blocking on manual approval. +- `tasks/get` maps the durable invocation row to MCP task state. +- `tasks/result` waits for completion and can respond as JSON or as `text/event-stream` when the client asks for SSE. + +Task negotiation behavior: +- Atryum records a lightweight MCP session profile from `initialize`, keyed by `MCP-Session-Id` when present +- sessions negotiated as `2025-11-25` see approval-gated tools as `execution.taskSupport: "required"` +- older sessions keep `execution.taskSupport: "optional"` for compatibility +- if a `2025-11-25` session calls a required tool without `params.task`, Atryum returns JSON-RPC `-32601` immediately instead of waiting for approval +- successful task-augmented calls are recorded on the session profile as evidence that the harness actually supports tasks + +Task status mapping: + +- `pending_approval` -> `input_required` +- `executing` / `approved` / `received` -> `working` +- `succeeded` -> `completed` +- `denied` / `failed` -> `failed` +- `cancelled` / `expired` -> `cancelled` + ## The rule engine Rules live in the `approval_rules` table and are evaluated in priority order (lowest `rule_order` first). Each rule has: @@ -135,6 +157,16 @@ After first-run bootstrap, edit MCP servers through the UI/API; TOML `[[upstream `server.database_url` selects the storage provider by URL scheme. 
`postgres://` and `postgresql://` use PostgreSQL via pgx stdlib; `sqlite://`, `file:`, an empty URL, or a bare path use SQLite. Normal tests do not require PostgreSQL; run the optional store integration test with `ATRYUM_POSTGRES_TESTS=1 go test ./internal/store`. +Transient MCP session/profile state uses a separate KV store configuration: + +```toml +[kv] +url = "" +default_ttl_seconds = 3600 +``` + +Empty or `memory://` uses the in-process memory store. Use `redis://` or `rediss://` for a shared KV store across load-balanced Atryum pods. + When `backend.base_url` is empty, the ValidMind backend connection check is skipped for local standalone runs. When it is set, startup fails if credentials are missing or `GET /api/atryum/unstable/connection` is rejected. Environment variables override TOML: `VM_BASE_URL`, `VM_MACHINE_KEY`, `VM_MACHINE_SECRET`, and `VM_CONNECTION_TIMEOUT_SECONDS`. ## Running diff --git a/atryum.example.toml b/atryum.example.toml index f8e375a..0329280 100644 --- a/atryum.example.toml +++ b/atryum.example.toml @@ -94,6 +94,12 @@ secret = "" [auth_debug] skip_verify = false +[kv] +# Empty or memory:// uses an in-process store. Use redis:// or rediss:// for a +# shared store across load-balanced Atryum pods. +url = "" +default_ttl_seconds = 3600 + # Bootstrap-only upstream definitions. # On startup, Atryum seeds mcp_servers from these entries only when the DB is empty. # After bootstrap, runtime resolution uses SQLite as the source of truth. 
diff --git a/cmd/atryum/main.go b/cmd/atryum/main.go index 1f261ac..cfb2da3 100644 --- a/cmd/atryum/main.go +++ b/cmd/atryum/main.go @@ -22,6 +22,7 @@ import ( "atryum/internal/config" "atryum/internal/invocation" "atryum/internal/invocation/policy" + "atryum/internal/kv" "atryum/internal/managedagents" "atryum/internal/mcp" "atryum/internal/store" @@ -218,7 +219,15 @@ func runServer(args []string) error { if backendClient != nil { syncAgentsFn = syncAgents } - handler := api.NewHandler(service, serverAdmin, policyRegistry, rulesRepo, agentsRepo, agentSyncSettingsRepo, llmConfigsRepo, syncAgentsFn, backendClient, localEvaluator) + kvStore, err := kv.NewStore(cfg.KV.URL) + if err != nil { + return fmt.Errorf("kv store: %w", err) + } + kvTTL := time.Duration(cfg.KV.DefaultTTLSeconds) * time.Second + if kvTTL <= 0 { + kvTTL = time.Hour + } + handler := api.NewHandler(service, serverAdmin, policyRegistry, rulesRepo, agentsRepo, agentSyncSettingsRepo, llmConfigsRepo, syncAgentsFn, backendClient, localEvaluator, api.WithKVStore(kvStore, kvTTL)) authValidator, err := auth.NewValidator(cfg.Auth, nil) if err != nil { diff --git a/go.mod b/go.mod index 6d234cc..3cc5709 100644 --- a/go.mod +++ b/go.mod @@ -8,10 +8,12 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.9.2 + github.com/redis/go-redis/v9 v9.19.0 modernc.org/sqlite v1.35.0 ) require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -21,9 +23,10 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect golang.org/x/sync v0.17.0 // indirect - 
golang.org/x/sys v0.28.0 // indirect + golang.org/x/sys v0.30.0 // indirect golang.org/x/text v0.29.0 // indirect modernc.org/libc v1.61.13 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index 0e860a7..58acfe6 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,12 @@ github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0 github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM= github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -21,6 +27,8 @@ github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/lann/builder 
v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= @@ -31,6 +39,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k= +github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -39,6 +49,10 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= +github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo= golang.org/x/exp 
v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= @@ -46,8 +60,8 @@ golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= diff --git a/internal/api/handlers.go b/internal/api/handlers.go index daee08e..ef4c6d1 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -25,6 +25,7 @@ import ( backendclient "atryum/internal/backend" "atryum/internal/invocation" "atryum/internal/invocation/policy" + "atryum/internal/kv" "atryum/internal/managedagents" "atryum/internal/mcp" authprovider "atryum/internal/mcp/auth_provider" @@ -36,10 +37,12 @@ var webFS embed.FS type service interface { Invoke(ctx context.Context, req invocation.CreateInvocationRequest) (invocation.InvocationResponse, error) + InvokeAsync(ctx context.Context, req invocation.CreateInvocationRequest, opts invocation.TaskCreateOptions) (invocation.InvocationResponse, error) ListTools(ctx context.Context, server string) ([]mcp.Tool, error) ListAllTools(ctx context.Context) ([]mcp.Tool, error) ResolveToolServer(ctx context.Context, toolName string) (string, error) Get(ctx context.Context, id string) 
(invocation.InvocationResponse, error) + WaitForCompletion(ctx context.Context, id string) (invocation.InvocationResponse, error) List(ctx context.Context, filter invocation.InvocationListFilter) (invocation.InvocationListResponse, error) ListAgentIDs(ctx context.Context) ([]string, error) Events(ctx context.Context, invocationID string, filter invocation.EventListFilter) (invocation.EventListResponse, error) @@ -77,6 +80,7 @@ type rulesRepo interface { Create(ctx context.Context, rule store.Rule) error Get(ctx context.Context, id string) (store.Rule, error) List(ctx context.Context) ([]store.Rule, error) + ListApprovalRules(ctx context.Context) ([]invocation.ApprovalRule, error) NextOrder(ctx context.Context) (int, error) Update(ctx context.Context, rule store.Rule) error Delete(ctx context.Context, id string) error @@ -149,6 +153,20 @@ type Handler struct { // managedAgents is the optional Claude Managed Agents events bridge. // nil when not configured (no anthropic api key). managedAgents managedAgentsAdmin + + kvStore kv.Store + mcpSessionTTL time.Duration +} + +type HandlerOption func(*Handler) + +type mcpSessionProfile struct { + SessionID string + ProtocolVersion string + ClientDeclaredTasks bool + SawTaskAugmentedToolCall bool + SawSyncCallToRequiredTaskTool bool + LastSeen time.Time } // managedAgentsAdmin is the slice of the managed-agents service the admin API @@ -459,7 +477,7 @@ type invocationStreamEnvelope struct { Items []invocation.InvocationResponse `json:"items"` } -func NewHandler(svc service, serverSvc serverService, policyRegistry *policy.Registry, rules rulesRepo, agents agentsRepo, agentSyncSettings agentSyncSettingsRepo, llmConfigs llmConfigsRepo, syncAgents func(ctx context.Context) error, bc *backendclient.Client, localSummarizer localInvocationSummarizer) *Handler { +func NewHandler(svc service, serverSvc serverService, policyRegistry *policy.Registry, rules rulesRepo, agents agentsRepo, agentSyncSettings agentSyncSettingsRepo, llmConfigs 
llmConfigsRepo, syncAgents func(ctx context.Context) error, bc *backendclient.Client, localSummarizer localInvocationSummarizer, opts ...HandlerOption) *Handler { staticSub, err := fs.Sub(webFS, "web") if err != nil { panic(err) @@ -469,7 +487,28 @@ func NewHandler(svc service, serverSvc serverService, policyRegistry *policy.Reg if f, ok := svc.(mcpEnvelopeForwarder); ok { forwarder = f } - return &Handler{svc: svc, serverSvc: serverSvc, policyRegistry: policyRegistry, rulesRepo: rules, agentsRepo: agents, agentSyncSettingsRepo: agentSyncSettings, llmConfigsRepo: llmConfigs, backendClient: bc, summarizeClient: bc, localSummarizer: localSummarizer, syncAgentsFn: syncAgents, forwarder: forwarder, staticHTTP: http.FileServer(http.FS(staticSub)), debug: debug, clientInfoCache: make(map[string]clientInfoSnapshot)} + handler := &Handler{svc: svc, serverSvc: serverSvc, policyRegistry: policyRegistry, rulesRepo: rules, agentsRepo: agents, agentSyncSettingsRepo: agentSyncSettings, llmConfigsRepo: llmConfigs, backendClient: bc, summarizeClient: bc, localSummarizer: localSummarizer, syncAgentsFn: syncAgents, forwarder: forwarder, staticHTTP: http.FileServer(http.FS(staticSub)), debug: debug, clientInfoCache: make(map[string]clientInfoSnapshot), kvStore: kv.NewMemoryStore(), mcpSessionTTL: time.Hour} + for _, opt := range opts { + opt(handler) + } + if handler.kvStore == nil { + handler.kvStore = kv.NewMemoryStore() + } + if handler.mcpSessionTTL <= 0 { + handler.mcpSessionTTL = time.Hour + } + return handler +} + +func WithKVStore(store kv.Store, defaultTTL time.Duration) HandlerOption { + return func(h *Handler) { + if store != nil { + h.kvStore = store + } + if defaultTTL > 0 { + h.mcpSessionTTL = defaultTTL + } + } } // SetAuthValidator installs the inbound auth validator. 
When non-nil, the @@ -887,7 +926,13 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server req.JSONRPC = "2.0" } requestID := compactRequestID(req.ID) + headerProtocolVersion := normalizeProtocolVersion(r.Header.Get("MCP-Protocol-Version")) protocolVersion := negotiateProtocolVersion(req.Method, r.Header.Get("MCP-Protocol-Version"), req.Params) + if req.Method != "initialize" && headerProtocolVersion == "" { + if profile := h.mcpProfileForRequest(r, server); profile != nil && profile.ProtocolVersion != "" { + protocolVersion = profile.ProtocolVersion + } + } h.debugf("server-side mcp request server=%s method=%s id=%s", server, req.Method, requestID) defer func() { h.debugf("server-side mcp complete server=%s method=%s id=%s duration_ms=%d", server, req.Method, requestID, time.Since(started).Milliseconds()) @@ -897,6 +942,10 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server switch req.Method { case "initialize": h.recordInitializeClientInfo(r, req.Params) + profile, sessionID := h.recordMCPInitialize(r, server, protocolVersion, req.Params) + if sessionID != "" { + w.Header().Set("MCP-Session-Id", sessionID) + } result := map[string]any{ "protocolVersion": protocolVersion, "serverInfo": map[string]any{ @@ -905,9 +954,21 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server }, "capabilities": map[string]any{ "tools": map[string]any{}, + "tasks": map[string]any{ + "requests": map[string]any{ + "tools": map[string]any{ + "call": map[string]any{}, + }, + }, + }, }, "instructions": atryumInitializeInstructions, } + if profile != nil { + result["_meta"] = map[string]any{ + "atryum/client-tasks-declared": profile.ClientDeclaredTasks, + } + } h.writeRPCResult(w, req.ID, result) case "notifications/initialized": w.WriteHeader(http.StatusAccepted) @@ -925,12 +986,15 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server } tools = appendAtryumTools(tools) _ = 
h.emitTraceEvent(r.Context(), server, "mcp.tools.list", map[string]any{"tool_count": len(tools), "request_id": requestID}) - annotated := h.annotateToolsWithPolicy(r.Context(), server, tools) - h.writeRPCResult(w, req.ID, map[string]any{"tools": annotated}) + profile := h.mcpProfileForRequest(r, server) + h.writeRPCResult(w, req.ID, map[string]any{"tools": h.withTaskSupport(r.Context(), tools, server, profile)}) case "tools/call": var params struct { Name string `json:"name"` Arguments map[string]any `json:"arguments"` + Task *struct { + TTL *int64 `json:"ttl,omitempty"` + } `json:"task,omitempty"` } if err := json.Unmarshal(req.Params, &params); err != nil { h.writeRPCError(w, req.ID, -32602, "invalid params") @@ -966,6 +1030,24 @@ func (h *Handler) handleMCPProxy(w http.ResponseWriter, r *http.Request, server toolReq.ClientName = snap.Name toolReq.ClientVersion = snap.Version } + if params.Task != nil { + h.markSawTaskAugmentedToolCall(r, server) + resp, err := h.svc.InvokeAsync(r.Context(), toolReq, invocation.TaskCreateOptions{TTLMillis: params.Task.TTL}) + if err != nil { + h.writeRPCError(w, req.ID, -32000, err.Error()) + return + } + tracePayload := map[string]any{"request_id": requestID, "status": resp.Status, "invocation_id": resp.InvocationID, "tool": params.Name, "task": true} + _ = h.emitTraceEvent(r.Context(), server, "mcp.tools.call", tracePayload) + h.writeRPCResult(w, req.ID, createTaskResult(resp, params.Task.TTL)) + return + } + profile := h.mcpProfileForRequest(r, server) + if h.requiresTaskCall(r.Context(), profile, callServer, params.Name) { + h.markSawSyncCallToRequiredTaskTool(r, server) + h.writeRPCError(w, req.ID, -32601, fmt.Sprintf("tool %q requires task-augmented tools/call; retry with params.task and poll tasks/get or tasks/result", params.Name)) + return + } resp, err := h.svc.Invoke(r.Context(), toolReq) if err != nil { h.writeRPCError(w, req.ID, -32000, err.Error()) @@ -982,6 +1064,42 @@ func (h *Handler) handleMCPProxy(w 
http.ResponseWriter, r *http.Request, server return } h.writeRPCResult(w, req.ID, normalizeToolCallResult(resp.Result, false)) + case "tasks/get": + var params struct { + TaskID string `json:"taskId"` + } + if err := json.Unmarshal(req.Params, &params); err != nil || strings.TrimSpace(params.TaskID) == "" { + h.writeRPCError(w, req.ID, -32602, "invalid params") + return + } + resp, err := h.svc.Get(r.Context(), params.TaskID) + if err != nil { + h.writeRPCError(w, req.ID, -32000, err.Error()) + return + } + h.writeRPCResult(w, req.ID, taskFromInvocation(resp, nil)) + case "tasks/result": + var params struct { + TaskID string `json:"taskId"` + } + if err := json.Unmarshal(req.Params, &params); err != nil || strings.TrimSpace(params.TaskID) == "" { + h.writeRPCError(w, req.ID, -32602, "invalid params") + return + } + resp, err := h.svc.WaitForCompletion(r.Context(), params.TaskID) + if err != nil { + if r.Context().Err() != nil { + return + } + h.writeRPCError(w, req.ID, -32000, err.Error()) + return + } + result := taskResultFromInvocation(resp) + if acceptsEventStream(r) { + h.writeRPCResultSSE(w, req.ID, result) + return + } + h.writeRPCResult(w, req.ID, result) default: forwarded, forwardedOK := h.forwardProxyEnvelope(r.Context(), server, req, protocolVersion) if !forwardedOK { @@ -1199,6 +1317,7 @@ type annotatedTool struct { Description string `json:"description,omitempty"` InputSchema json.RawMessage `json:"inputSchema,omitempty"` Annotations *atryumAnnotations `json:"annotations,omitempty"` + Execution map[string]any `json:"execution,omitempty"` } type atryumAnnotations struct { @@ -1219,14 +1338,14 @@ func (h *Handler) annotateToolsWithPolicy(ctx context.Context, server string, to out := make([]any, len(tools)) if h.rulesRepo == nil || strings.TrimSpace(server) == "" { for i, t := range tools { - out[i] = t + out[i] = toolListEntry(t, "", "") } return out } rules, err := h.rulesRepo.List(ctx) if err != nil { for i, t := range tools { - out[i] = t + out[i] = 
toolListEntry(t, "", "") } return out } @@ -1234,7 +1353,7 @@ func (h *Handler) annotateToolsWithPolicy(ctx context.Context, server string, to agentCUID := h.resolveAgentRecordForRules(ctx, agentID) for i, t := range tools { if t.Name == agentRulesToolName { - out[i] = t + out[i] = toolListEntry(t, "", "") continue } action, matched := effectiveActionForTool(rules, server, t.Name, agentID, agentCUID) @@ -1247,19 +1366,34 @@ func (h *Handler) annotateToolsWithPolicy(ctx context.Context, server string, to desc = prefix + desc } } - out[i] = annotatedTool{ - Name: t.Name, - Description: desc, - InputSchema: t.InputSchema, - Annotations: &atryumAnnotations{Atryum: atryumToolPolicy{ - EffectiveAction: action, - MatchedRuleID: matched, - }}, - } + out[i] = toolListEntry(t, desc, action, matched) } return out } +func toolListEntry(tool mcp.Tool, description string, policyParts ...string) annotatedTool { + if description == "" { + description = tool.Description + } + entry := annotatedTool{ + Name: tool.Name, + Description: description, + InputSchema: tool.InputSchema, + Execution: map[string]any{"taskSupport": "optional"}, + } + if len(policyParts) > 0 && policyParts[0] != "" { + matched := "" + if len(policyParts) > 1 { + matched = policyParts[1] + } + entry.Annotations = &atryumAnnotations{Atryum: atryumToolPolicy{ + EffectiveAction: policyParts[0], + MatchedRuleID: matched, + }} + } + return entry +} + // effectiveActionForTool returns the action of the first enabled rule that // matches (server, tool, agentID, agentCUID), mirroring invocation.matchRules priority order. // When no rule matches, it returns RuleActionHumanApproval (the default). 
@@ -3321,6 +3455,18 @@ func (h *Handler) writeRPCResult(w http.ResponseWriter, id json.RawMessage, resu writeJSON(w, http.StatusOK, jsonRPCResponse{JSONRPC: "2.0", ID: id, Result: body}) } +func (h *Handler) writeRPCResultSSE(w http.ResponseWriter, id json.RawMessage, result any) { + body, _ := json.Marshal(jsonRPCResponse{JSONRPC: "2.0", ID: id, Result: mustRawJSON(result)}) + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "event: message\ndata: %s\n\n", body) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } +} + func (h *Handler) forwardProxyEnvelope(ctx context.Context, server string, envelope mcp.Envelope, protocolVersion string) (mcp.ForwardResult, bool) { if h.forwarder == nil { return mcp.ForwardResult{}, false @@ -3492,6 +3638,290 @@ func sessionKeyFromRequest(r *http.Request) string { return "addr:" + host + "|ua:" + ua } +func mustRawJSON(v any) json.RawMessage { + b, _ := json.Marshal(v) + return b +} + +func createTaskResult(resp invocation.InvocationResponse, ttlMillis *int64) map[string]any { + result := map[string]any{ + "task": taskFromInvocation(resp, ttlMillis), + } + result["_meta"] = map[string]any{ + "io.modelcontextprotocol/model-immediate-response": taskImmediateResponse(resp), + } + return result +} + +func taskFromInvocation(resp invocation.InvocationResponse, ttlMillis *int64) map[string]any { + lastUpdated := resp.SubmittedAt + if resp.CompletedAt != nil { + lastUpdated = *resp.CompletedAt + } + task := map[string]any{ + "taskId": resp.InvocationID, + "status": taskStatusFromInvocation(resp.Status), + "statusMessage": taskStatusMessage(resp), + "createdAt": resp.SubmittedAt.UTC().Format(time.RFC3339), + "lastUpdatedAt": lastUpdated.UTC().Format(time.RFC3339), + "pollInterval": 2000, + "ttl": nil, + } + if ttlMillis != nil { + task["ttl"] = *ttlMillis + } + return task +} + 
+func taskResultFromInvocation(resp invocation.InvocationResponse) any { + var raw json.RawMessage + isError := false + if len(resp.Error) > 0 { + raw = resp.Error + isError = true + } else { + raw = resp.Result + } + normalized := normalizeToolCallResult(raw, isError) + body, ok := normalized.(map[string]any) + if !ok { + return normalized + } + meta, _ := body["_meta"].(map[string]any) + if meta == nil { + meta = map[string]any{} + } + meta["io.modelcontextprotocol/related-task"] = map[string]any{"taskId": resp.InvocationID} + body["_meta"] = meta + return body +} + +func taskStatusFromInvocation(status invocation.Status) string { + switch status { + case invocation.StatusSucceeded: + return "completed" + case invocation.StatusCancelled, invocation.StatusExpired: + return "cancelled" + case invocation.StatusDenied, invocation.StatusFailed: + return "failed" + case invocation.StatusPendingApproval: + return "input_required" + default: + return "working" + } +} + +func taskStatusMessage(resp invocation.InvocationResponse) string { + switch resp.Status { + case invocation.StatusPendingApproval: + return "Awaiting approval." + case invocation.StatusApproved: + return "Approved. Execution will begin shortly." + case invocation.StatusExecuting: + return "The operation is now in progress." + case invocation.StatusSucceeded: + return "The operation completed successfully." + case invocation.StatusDenied: + return "The operation was denied by policy or a reviewer." + case invocation.StatusCancelled: + return "The operation was cancelled." + case invocation.StatusExpired: + return "The operation expired before completion." + case invocation.StatusFailed: + return "The operation failed." + default: + return "The operation is now in progress." + } +} + +func taskImmediateResponse(resp invocation.InvocationResponse) string { + switch resp.Status { + case invocation.StatusPendingApproval: + return "Tool call submitted and awaiting approval. 
Continue with other work while the task is pending." + case invocation.StatusExecuting, invocation.StatusApproved: + return "Tool call accepted as a task and is running asynchronously." + case invocation.StatusDenied, invocation.StatusFailed: + return "Tool call completed with an error. Retrieve the task result for details." + default: + return "Tool call accepted as a task." + } +} + +func (h *Handler) withTaskSupport(ctx context.Context, tools []mcp.Tool, server string, profile *mcpSessionProfile) []any { + out := h.annotateToolsWithPolicy(ctx, server, tools) + for i, tool := range tools { + taskSupport := "optional" + if h.requiresTaskCall(ctx, profile, server, tool.Name) { + taskSupport = "required" + } + switch entry := out[i].(type) { + case annotatedTool: + entry.Execution = map[string]any{"taskSupport": taskSupport} + out[i] = entry + case mcp.Tool: + item := toolListEntry(entry, "", "") + item.Execution = map[string]any{"taskSupport": taskSupport} + out[i] = item + } + } + return out +} + +func (h *Handler) requiresTaskCall(ctx context.Context, profile *mcpSessionProfile, server string, tool string) bool { + if profile == nil || profile.ProtocolVersion != mcp.MCPProtocolVersion2025 { + return false + } + return h.toolRequiresApproval(ctx, server, tool) +} + +func (h *Handler) toolRequiresApproval(ctx context.Context, server string, tool string) bool { + if h.rulesRepo != nil { + if rules, err := h.rulesRepo.ListApprovalRules(ctx); err == nil { + if matched := firstMatchingApprovalRule(rules, server, tool, ""); matched != nil { + return matched.Action == invocation.RuleActionHumanApproval + } + } + } + if h.policyRegistry == nil { + return false + } + active := h.policyRegistry.Active() + if active == nil { + return true + } + return active.ID() == "manual_approval" || active.ID() == "registry" +} + +func firstMatchingApprovalRule(rules []invocation.ApprovalRule, server string, tool string, user string) *invocation.ApprovalRule { + for i := range rules { + 
rule := &rules[i] + if !rule.Enabled { + continue + } + if !matchesPatternList(rule.ServerPatterns, server) { + continue + } + if !matchesPatternList(rule.ToolPatterns, tool) { + continue + } + if !matchesUserPattern(rule.AgentIDPattern, user) { + continue + } + return rule + } + return nil +} + +func matchesPatternList(patterns []string, value string) bool { + if len(patterns) == 0 { + return true + } + for _, pattern := range patterns { + if pattern == "*" || pattern == value { + return true + } + } + return false +} + +func matchesUserPattern(pattern string, user string) bool { + return pattern == "" || pattern == "*" || pattern == user +} + +func (h *Handler) recordMCPInitialize(r *http.Request, server string, protocolVersion string, params json.RawMessage) (*mcpSessionProfile, string) { + sessionID := strings.TrimSpace(r.Header.Get("MCP-Session-Id")) + if sessionID == "" { + var err error + sessionID, err = randomToken(18) + if err != nil { + sessionID = fallbackMCPSessionKey(r, server) + } + } + profile := &mcpSessionProfile{ + SessionID: sessionID, + ProtocolVersion: protocolVersion, + ClientDeclaredTasks: clientDeclaredTasksCapability(params), + LastSeen: time.Now().UTC(), + } + _ = h.kvStore.Set(r.Context(), mcpSessionKey(sessionID), profile, h.mcpSessionTTL) + _ = h.kvStore.Set(r.Context(), mcpSessionKey(fallbackMCPSessionKey(r, server)), profile, h.mcpSessionTTL) + return profile, sessionID +} + +func (h *Handler) mcpProfileForRequest(r *http.Request, server string) *mcpSessionProfile { + key := strings.TrimSpace(r.Header.Get("MCP-Session-Id")) + if key == "" { + key = fallbackMCPSessionKey(r, server) + } + var profile mcpSessionProfile + found, err := h.kvStore.Get(r.Context(), mcpSessionKey(key), &profile) + if err != nil || !found { + if version := normalizeProtocolVersion(r.Header.Get("MCP-Protocol-Version")); version != "" { + return &mcpSessionProfile{ProtocolVersion: version, LastSeen: time.Now().UTC()} + } + return nil + } + profile.LastSeen = 
time.Now().UTC() + _ = h.kvStore.Set(r.Context(), mcpSessionKey(key), profile, h.mcpSessionTTL) + return &profile +} + +func (h *Handler) markSawTaskAugmentedToolCall(r *http.Request, server string) { + h.updateMCPProfile(r, server, func(profile *mcpSessionProfile) { + profile.SawTaskAugmentedToolCall = true + }) +} + +func (h *Handler) markSawSyncCallToRequiredTaskTool(r *http.Request, server string) { + h.updateMCPProfile(r, server, func(profile *mcpSessionProfile) { + profile.SawSyncCallToRequiredTaskTool = true + }) +} + +func (h *Handler) updateMCPProfile(r *http.Request, server string, update func(*mcpSessionProfile)) { + key := strings.TrimSpace(r.Header.Get("MCP-Session-Id")) + if key == "" { + key = fallbackMCPSessionKey(r, server) + } + _, _ = h.kvStore.Update(r.Context(), mcpSessionKey(key), h.mcpSessionTTL, func(payload []byte) ([]byte, error) { + var profile mcpSessionProfile + if err := json.Unmarshal(payload, &profile); err != nil { + return nil, err + } + update(&profile) + profile.LastSeen = time.Now().UTC() + return json.Marshal(profile) + }) + if key != fallbackMCPSessionKey(r, server) { + if profile := h.mcpProfileForRequest(r, server); profile != nil { + _ = h.kvStore.Set(r.Context(), mcpSessionKey(fallbackMCPSessionKey(r, server)), profile, h.mcpSessionTTL) + } + } +} + +func fallbackMCPSessionKey(r *http.Request, server string) string { + return "remote:" + server + ":" + r.RemoteAddr +} + +func mcpSessionKey(sessionID string) string { + return "mcp_session:" + sessionID +} + +func clientDeclaredTasksCapability(params json.RawMessage) bool { + var payload struct { + Capabilities map[string]json.RawMessage `json:"capabilities"` + } + if err := json.Unmarshal(params, &payload); err != nil { + return false + } + raw, ok := payload.Capabilities["tasks"] + return ok && len(raw) > 0 && string(raw) != "null" +} + +func acceptsEventStream(r *http.Request) bool { + return strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/event-stream") 
+} + func compactRequestID(id json.RawMessage) string { if len(id) == 0 { return "" diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index 0b503a5..e4ae514 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -15,23 +15,26 @@ import ( "atryum/internal/auth" backendclient "atryum/internal/backend" "atryum/internal/invocation" + "atryum/internal/invocation/policy" "atryum/internal/managedagents" "atryum/internal/mcp" "atryum/internal/store" ) type stubService struct { - tools []mcp.Tool - invoke invocation.InvocationResponse - invErr error - getErr error - setResp invocation.InvocationResponse - setID string - setText string - listErr error - upstream mcp.Upstream - forward mcp.ForwardResult - fwdErr error + tools []mcp.Tool + invoke invocation.InvocationResponse + invokeAsync invocation.InvocationResponse + waitForCompletion invocation.InvocationResponse + invErr error + getErr error + setResp invocation.InvocationResponse + setID string + setText string + listErr error + upstream mcp.Upstream + forward mcp.ForwardResult + fwdErr error invokedReq *invocation.CreateInvocationRequest invokedCtx context.Context @@ -42,6 +45,13 @@ func (s *stubService) Invoke(ctx context.Context, req invocation.CreateInvocatio s.invokedCtx = ctx return s.invoke, s.invErr } +func (s *stubService) InvokeAsync(_ context.Context, req invocation.CreateInvocationRequest, _ invocation.TaskCreateOptions) (invocation.InvocationResponse, error) { + s.invokedReq = &req + if s.invokeAsync.InvocationID != "" { + return s.invokeAsync, s.invErr + } + return s.invoke, s.invErr +} func (s *stubService) ListTools(context.Context, string) ([]mcp.Tool, error) { return s.tools, s.listErr } @@ -54,6 +64,12 @@ func (s *stubService) ResolveToolServer(_ context.Context, _ string) (string, er func (s *stubService) Get(_ context.Context, _ string) (invocation.InvocationResponse, error) { return s.invoke, s.getErr } +func (s *stubService) 
WaitForCompletion(context.Context, string) (invocation.InvocationResponse, error) { + if s.waitForCompletion.InvocationID != "" { + return s.waitForCompletion, nil + } + return s.invoke, nil +} func (s *stubService) List(context.Context, invocation.InvocationListFilter) (invocation.InvocationListResponse, error) { return invocation.InvocationListResponse{Items: []invocation.InvocationResponse{s.invoke}, Total: 1, Limit: 50}, nil } @@ -128,6 +144,20 @@ func (s *stubRulesRepo) Get(_ context.Context, id string) (store.Rule, error) { func (s *stubRulesRepo) List(context.Context) ([]store.Rule, error) { return s.rules, s.err } +func (s *stubRulesRepo) ListApprovalRules(context.Context) ([]invocation.ApprovalRule, error) { + out := make([]invocation.ApprovalRule, 0, len(s.rules)) + for _, rule := range s.rules { + out = append(out, invocation.ApprovalRule{ + ID: rule.ID, + ServerPatterns: rule.ServerPatterns, + ToolPatterns: rule.ToolPatterns, + AgentIDPattern: rule.AgentIDPattern, + Action: rule.Action, + Enabled: rule.Enabled, + }) + } + return out, s.err +} func (s *stubRulesRepo) NextOrder(context.Context) (int, error) { return len(s.rules), nil } func (s *stubRulesRepo) Update(context.Context, store.Rule) error { return nil } func (s *stubRulesRepo) Delete(context.Context, string) error { return nil } @@ -590,6 +620,50 @@ func TestMCPToolsList(t *testing.T) { } } +func TestMCPToolsListMarksApprovalToolsRequiredFor2025Session(t *testing.T) { + registry := policy.NewRegistry(policy.ManualApprovalProvider{}) + if err := registry.SetActive("manual_approval"); err != nil { + t.Fatal(err) + } + h := NewHandler(&stubService{tools: []mcp.Tool{{Name: "demo_tool"}}}, stubServerService{}, registry, nil, nil, nil, nil, nil, nil, nil) + + initReq := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{"tasks":{}}}}`)) + initReq.Header.Set("MCP-Session-Id", 
"session-required") + h.Routes().ServeHTTP(httptest.NewRecorder(), initReq) + + req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}`)) + req.Header.Set("MCP-Session-Id", "session-required") + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if !strings.Contains(w.Body.String(), `"taskSupport":"required"`) { + t.Fatalf("expected required task support, got %s", w.Body.String()) + } +} + +func TestMCPToolsListKeepsApprovalToolsOptionalForLegacySession(t *testing.T) { + registry := policy.NewRegistry(policy.ManualApprovalProvider{}) + if err := registry.SetActive("manual_approval"); err != nil { + t.Fatal(err) + } + h := NewHandler(&stubService{tools: []mcp.Tool{{Name: "demo_tool"}}}, stubServerService{}, registry, nil, nil, nil, nil, nil, nil, nil) + + initReq := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{}}}`)) + initReq.Header.Set("MCP-Session-Id", "session-legacy") + h.Routes().ServeHTTP(httptest.NewRecorder(), initReq) + + req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}`)) + req.Header.Set("MCP-Session-Id", "session-legacy") + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if !strings.Contains(w.Body.String(), `"taskSupport":"optional"`) { + t.Fatalf("expected optional task support, got %s", w.Body.String()) + } +} + func TestMCPToolsCallInterceptsInvocation(t *testing.T) { now := time.Now().UTC() svc := &stubService{invoke: invocation.InvocationResponse{InvocationID: "inv_123", ServerName: "demo", ToolName: "demo_tool", Status: invocation.StatusSucceeded, Input: json.RawMessage(`{"a":1}`), SubmittedAt: now, CompletedAt: &now, Result: json.RawMessage(`{"content":[{"type":"text","text":"ok"}]}`)}} @@ -613,6 +687,36 @@ func 
TestMCPToolsCallInterceptsInvocation(t *testing.T) { } } +func TestMCPToolsCallRejectsSyncCallToRequiredTaskTool(t *testing.T) { + registry := policy.NewRegistry(policy.ManualApprovalProvider{}) + if err := registry.SetActive("manual_approval"); err != nil { + t.Fatal(err) + } + now := time.Now().UTC() + svc := &stubService{invoke: invocation.InvocationResponse{InvocationID: "inv_123", ServerName: "demo", ToolName: "demo_tool", Status: invocation.StatusSucceeded, SubmittedAt: now, CompletedAt: &now, Result: json.RawMessage(`{"content":[{"type":"text","text":"ok"}]}`)}} + h := NewHandler(svc, stubServerService{}, registry, nil, nil, nil, nil, nil, nil, nil) + + initReq := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{}}}`)) + initReq.Header.Set("MCP-Session-Id", "session-sync-reject") + h.Routes().ServeHTTP(httptest.NewRecorder(), initReq) + + req := httptest.NewRequest(http.MethodPost, "/mcp/demo", strings.NewReader(`{"jsonrpc":"2.0","id":7,"method":"tools/call","params":{"name":"demo_tool","arguments":{"a":1}}}`)) + req.Header.Set("MCP-Session-Id", "session-sync-reject") + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if svc.invokedReq != nil { + t.Fatal("expected sync required-task call to be rejected before invocation") + } + if !strings.Contains(w.Body.String(), `"code":-32601`) { + t.Fatalf("expected method-not-found error, got %s", w.Body.String()) + } + if !strings.Contains(w.Body.String(), `requires task-augmented tools/call`) { + t.Fatalf("expected task-required message, got %s", w.Body.String()) + } +} + func TestMCPNoAuthAgentIDQueryHintSetsIdentity(t *testing.T) { now := time.Now().UTC() svc := &stubService{invoke: invocation.InvocationResponse{InvocationID: "inv_123", ServerName: "demo", ToolName: "demo_tool", Status: invocation.StatusSucceeded, SubmittedAt: now, CompletedAt: &now, Result: 
json.RawMessage(`{"content":[{"type":"text","text":"ok"}]}`)}} @@ -935,6 +1039,66 @@ func TestMCPToolsCallDenialIncludesRulesContext(t *testing.T) { } } +func TestMCPToolsCallWithTaskReturnsCreateTaskResult(t *testing.T) { + now := time.Now().UTC() + svc := &stubService{invokeAsync: invocation.InvocationResponse{ + InvocationID: "inv_task_123", + ServerName: "github", + ToolName: "list_commits", + Status: invocation.StatusPendingApproval, + Input: json.RawMessage(`{"repo":"openai/openai"}`), + SubmittedAt: now, + }} + h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) + req := httptest.NewRequest(http.MethodPost, "/mcp/github", strings.NewReader(`{"jsonrpc":"2.0","id":9,"method":"tools/call","params":{"name":"list_commits","arguments":{"repo":"openai/openai"},"task":{"ttl":60000}}}`)) + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if !strings.Contains(w.Body.String(), `"taskId":"inv_task_123"`) { + t.Fatalf("expected task id in response, got %s", w.Body.String()) + } + if !strings.Contains(w.Body.String(), `"status":"input_required"`) { + t.Fatalf("expected input_required task status, got %s", w.Body.String()) + } + if !strings.Contains(w.Body.String(), `model-immediate-response`) { + t.Fatalf("expected model immediate response meta, got %s", w.Body.String()) + } +} + +func TestMCPTasksResultCanReturnSSE(t *testing.T) { + now := time.Now().UTC() + svc := &stubService{waitForCompletion: invocation.InvocationResponse{ + InvocationID: "inv_task_456", + ServerName: "github", + ToolName: "list_commits", + Status: invocation.StatusSucceeded, + Input: json.RawMessage(`{"repo":"openai/openai"}`), + SubmittedAt: now, + CompletedAt: &now, + Result: json.RawMessage(`{"content":[{"type":"text","text":"done"}]}`), + }} + h := NewHandler(svc, stubServerService{}, nil, nil, nil, nil, nil, nil, nil, nil) + req := httptest.NewRequest(http.MethodPost, 
"/mcp/github", strings.NewReader(`{"jsonrpc":"2.0","id":10,"method":"tasks/result","params":{"taskId":"inv_task_456"}}`)) + req.Header.Set("Accept", "text/event-stream") + w := httptest.NewRecorder() + + h.Routes().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if ct := w.Header().Get("Content-Type"); !strings.Contains(ct, "text/event-stream") { + t.Fatalf("expected text/event-stream, got %s", ct) + } + if !strings.Contains(w.Body.String(), `"io.modelcontextprotocol/related-task":{"taskId":"inv_task_456"}`) { + t.Fatalf("expected related-task meta, got %s", w.Body.String()) + } +} + func TestAdminInvocationsResponsesIncludeServerToolAndInput(t *testing.T) { now := time.Now().UTC() svc := &stubService{invoke: invocation.InvocationResponse{InvocationID: "inv_123", ServerName: "demo-server", ToolName: "demo_tool", Status: invocation.StatusSucceeded, Input: json.RawMessage(`{"issue":123,"verbose":true}`), SubmittedAt: now, CompletedAt: &now}} diff --git a/internal/config/config.go b/internal/config/config.go index 7d647f4..def5f08 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,11 +9,11 @@ import ( "atryum/internal/auth" ) - type Config struct { Server ServerConfig `toml:"server"` Backend BackendConfig `toml:"backend"` Defaults DefaultsConfig `toml:"defaults"` + KV KVConfig `toml:"kv"` Policy PolicyConfig `toml:"policy"` Upstreams []UpstreamConfig `toml:"upstreams"` // Auth holds zero or more inbound OAuth bearer-token validators @@ -89,6 +89,11 @@ type AuthDebugConfig struct { SkipVerify bool `toml:"skip_verify"` } +type KVConfig struct { + URL string `toml:"url"` + DefaultTTLSeconds int `toml:"default_ttl_seconds"` +} + // PolicyConfig selects the active approval policy provider at startup. // Valid provider values: "always_approve", "manual_approval", "always_deny". 
type PolicyConfig struct { @@ -142,6 +147,9 @@ func Load(path string) (Config, error) { Defaults: DefaultsConfig{ RequestTimeoutSeconds: 30, }, + KV: KVConfig{ + DefaultTTLSeconds: 3600, + }, } _, err := toml.DecodeFile(path, &cfg) if err != nil && !os.IsNotExist(err) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 534ccc7..5976d6b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -73,3 +73,44 @@ func TestLoadMissingConfigUsesDefaultsAndEnv(t *testing.T) { t.Fatalf("Backend.APISecret = %q", cfg.Backend.APISecret) } } + +func TestLoadKVConfig(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "atryum.toml") + if err := os.WriteFile(path, []byte(`[kv] +url = "memory://" +default_ttl_seconds = 120 +`), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.KV.URL != "memory://" { + t.Fatalf("KV.URL = %q", cfg.KV.URL) + } + if cfg.KV.DefaultTTLSeconds != 120 { + t.Fatalf("KV.DefaultTTLSeconds = %d", cfg.KV.DefaultTTLSeconds) + } +} + +func TestLoadKVConfigDefaults(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "atryum.toml") + if err := os.WriteFile(path, []byte(``), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.KV.URL != "" { + t.Fatalf("KV.URL default = %q", cfg.KV.URL) + } + if cfg.KV.DefaultTTLSeconds != 3600 { + t.Fatalf("KV.DefaultTTLSeconds default = %d", cfg.KV.DefaultTTLSeconds) + } +} diff --git a/internal/invocation/service.go b/internal/invocation/service.go index b7ec0ff..1c5373b 100644 --- a/internal/invocation/service.go +++ b/internal/invocation/service.go @@ -96,15 +96,15 @@ type SyncSettingsProvider interface { // EvaluateRequest mirrors backend.EvaluateRequest so the service package does // not import the backend package directly. 
type EvaluateRequest struct { - ModelConfigCUID string `json:"model_config_cuid"` - OrgCUID string `json:"org_cuid,omitempty"` - AgentVMCUID string `json:"agent_vm_cuid,omitempty"` - ConstitutionFieldKey string `json:"constitution_field_key,omitempty"` + ModelConfigCUID string `json:"model_config_cuid"` + OrgCUID string `json:"org_cuid,omitempty"` + AgentVMCUID string `json:"agent_vm_cuid,omitempty"` + ConstitutionFieldKey string `json:"constitution_field_key,omitempty"` // AtryumLLMConfigID references a local LLM config for native evaluation. // When set, the local evaluator is used instead of the VM backend. AtryumLLMConfigID string `json:"atryum_llm_config_id,omitempty"` // Constitution is the agent's governing text sent to the local LLM judge. - Constitution string `json:"constitution,omitempty"` + Constitution string `json:"constitution,omitempty"` ServerName string `json:"server_name"` ToolName string `json:"tool_name"` ToolArgs map[string]any `json:"tool_args,omitempty"` @@ -179,6 +179,10 @@ type Service struct { pendingApprovals map[string]chan approvalDecision } +type TaskCreateOptions struct { + TTLMillis *int64 +} + func NewService( inv invocationRepo, evt eventRepo, @@ -212,29 +216,103 @@ func (s *Service) SetInvocationSummarizer(client SummaryClient) { s.summarizer = client } +func (s *Service) InvokeAsync(ctx context.Context, req CreateInvocationRequest, opts TaskCreateOptions) (InvocationResponse, error) { + inv, upstream, decision, _, aiConfidence, existing, err := s.prepareInvocation(ctx, req) + if err != nil { + return InvocationResponse{}, err + } + if existing { + return s.toResponse(inv), nil + } + switch decision.Disposition { + case policy.DispositionNever: + resp, err := s.denyByPolicy(ctx, inv, decision.Reason, aiConfidence) + if err != nil { + return InvocationResponse{}, err + } + return resp, nil + case policy.DispositionAuto: + inv.Status = StatusExecuting + inv.Approval = newApproval("auto_approved", decision.Reason, aiConfidence) + if 
err := s.invocations.UpdateResult(ctx, inv); err != nil { + return InvocationResponse{}, err + } + _ = s.events.Create(ctx, Event{ + InvocationID: inv.InvocationID, + EventType: "invocation.executing", + Payload: mustJSON(map[string]any{ + "upstream": upstream.Name, "request_id": req.RequestID, + "input": json.RawMessage(inv.Input), "arguments": json.RawMessage(inv.Input), + "auto_approved": true, "auto_reason": decision.Reason, "task": true, + }), + CreatedAt: time.Now().UTC(), + }) + s.executeInvocationAsync(inv, upstream, req) + return s.toResponse(inv), nil + default: + inv.Status = StatusPendingApproval + if err := s.invocations.UpdateResult(ctx, inv); err != nil { + return InvocationResponse{}, err + } + _ = s.events.Create(ctx, Event{ + InvocationID: inv.InvocationID, + EventType: "invocation.pending_approval", + Payload: mustJSON(map[string]any{ + "tool": req.Tool, "upstream": upstream.Name, + "request_id": req.RequestID, + "input": json.RawMessage(inv.Input), "arguments": json.RawMessage(inv.Input), + "task": true, + }), + CreatedAt: time.Now().UTC(), + }) + return s.toResponse(inv), nil + } +} + func (s *Service) Invoke(ctx context.Context, req CreateInvocationRequest) (InvocationResponse, error) { + inv, upstream, decision, _, aiConfidence, existing, err := s.prepareInvocation(ctx, req) + if err != nil { + return InvocationResponse{}, err + } + if existing { + return s.toResponse(inv), nil + } + switch decision.Disposition { + case policy.DispositionNever: + return s.denyByPolicy(ctx, inv, decision.Reason, aiConfidence) + case policy.DispositionAuto: + return s.executeNow(ctx, inv, upstream, req, decision.Reason, aiConfidence) + default: + // DispositionHuman, DispositionWorkflow, and dispositionAIEscalated all gate + // on a human decision. AI-escalated invocations are already tagged on inv.Approval + // above; waitForHumanApproval will persist the pending_approval status. 
+ return s.waitForHumanApproval(ctx, inv, upstream, req) + } +} + +func (s *Service) prepareInvocation(ctx context.Context, req CreateInvocationRequest) (Invocation, mcp.Upstream, policy.Decision, *string, *float64, bool, error) { if req.Server == "" { - return InvocationResponse{}, fmt.Errorf("server is required") + return Invocation{}, mcp.Upstream{}, policy.Decision{}, nil, nil, false, fmt.Errorf("server is required") } if req.Tool == "" { - return InvocationResponse{}, fmt.Errorf("tool is required") + return Invocation{}, mcp.Upstream{}, policy.Decision{}, nil, nil, false, fmt.Errorf("tool is required") } if req.IdempotencyKey != nil && *req.IdempotencyKey != "" { existing, err := s.invocations.GetByIdempotencyKey(ctx, *req.IdempotencyKey) if err == nil { - return s.toResponse(existing), nil + return existing, mcp.Upstream{Name: existing.Upstream}, policy.Decision{}, existing.MatchedRuleID, nil, true, nil } if err != nil && err != sql.ErrNoRows { - return InvocationResponse{}, err + return Invocation{}, mcp.Upstream{}, policy.Decision{}, nil, nil, false, err } } upstream, err := s.resolver.ResolveContext(ctx, req.Server) if err != nil { - return InvocationResponse{}, err + return Invocation{}, mcp.Upstream{}, policy.Decision{}, nil, nil, false, err } inputJSON, err := json.Marshal(req.Input) if err != nil { - return InvocationResponse{}, err + return Invocation{}, mcp.Upstream{}, policy.Decision{}, nil, nil, false, err } // agentID is the authenticated agent identity from middleware. When auth @@ -266,7 +344,7 @@ func (s *Service) Invoke(ctx context.Context, req CreateInvocationRequest) (Invo inv.ClientVersion = &v } if err := s.invocations.Create(ctx, inv); err != nil { - return InvocationResponse{}, err + return Invocation{}, mcp.Upstream{}, policy.Decision{}, nil, nil, false, err } // Determine disposition: check rules first (fine-grained), then fall back to policy (global). 
@@ -347,18 +425,7 @@ func (s *Service) Invoke(ctx context.Context, req CreateInvocationRequest) (Invo Payload: mustJSON(receivedPayload), CreatedAt: now, }) - - switch decision.Disposition { - case policy.DispositionNever: - return s.denyByPolicy(ctx, inv, decision.Reason, aiConfidence) - case policy.DispositionAuto: - return s.executeNow(ctx, inv, upstream, req, decision.Reason, aiConfidence) - default: - // DispositionHuman, DispositionWorkflow, and dispositionAIEscalated all gate - // on a human decision. AI-escalated invocations are already tagged on inv.Approval - // above; waitForHumanApproval will persist the pending_approval status. - return s.waitForHumanApproval(ctx, inv, upstream, req) - } + return inv, upstream, decision, matchedRuleID, aiConfidence, false, nil } // resolveAgentRecord looks up the Atryum agent record for the given runtime @@ -734,6 +801,9 @@ func (s *Service) recordExternalDecision(ctx context.Context, invocationID strin return err } _ = s.events.Create(ctx, Event{InvocationID: inv.InvocationID, EventType: "invocation.approved", Payload: mustJSON(map[string]any{"input": json.RawMessage(inv.Input), "arguments": json.RawMessage(inv.Input)}), CreatedAt: now}) + if inv.Upstream != "external" { + go s.executeApprovedInvocation(inv) + } return nil } inv.Status = StatusDenied @@ -993,6 +1063,26 @@ func (s *Service) summarizePendingApproval(invocationID string) { }() } +func (s *Service) WaitForCompletion(ctx context.Context, invocationID string) (InvocationResponse, error) { + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + for { + inv, err := s.invocations.Get(ctx, invocationID) + if err != nil { + return InvocationResponse{}, err + } + resp := s.toResponse(inv) + if isTerminalStatus(resp.Status) { + return resp, nil + } + select { + case <-ctx.Done(): + return InvocationResponse{}, ctx.Err() + case <-ticker.C: + } + } +} + // RecordExecution updates an externally-executed invocation with the outcome // reported by 
the executor. Valid execStatus values: // @@ -1209,6 +1299,62 @@ func (s *Service) toResponse(inv Invocation) InvocationResponse { return resp } +func (s *Service) executeInvocationAsync(inv Invocation, upstream mcp.Upstream, req CreateInvocationRequest) { + go func() { + _, _ = s.finishExecution(context.Background(), inv, upstream, req) + }() +} + +func (s *Service) executeApprovedInvocation(inv Invocation) { + var input map[string]any + if len(inv.Input) > 0 { + _ = json.Unmarshal(inv.Input, &input) + } + req := CreateInvocationRequest{ + Server: inv.Upstream, + Tool: inv.Tool, + Input: input, + RequestID: inv.RequestID, + IdempotencyKey: inv.IdempotencyKey, + } + upstream, err := s.resolver.ResolveContext(context.Background(), inv.Upstream) + if err != nil { + completed := time.Now().UTC() + inv.Status = StatusFailed + inv.CompletedAt = &completed + inv.Error = mustJSON(map[string]any{"message": err.Error()}) + _ = s.invocations.UpdateResult(context.Background(), inv) + _ = s.events.Create(context.Background(), Event{InvocationID: inv.InvocationID, EventType: "invocation.failed", Payload: inv.Error, CreatedAt: completed}) + return + } + inv.Status = StatusExecuting + if err := s.invocations.UpdateResult(context.Background(), inv); err != nil { + return + } + _ = s.events.Create(context.Background(), Event{ + InvocationID: inv.InvocationID, + EventType: "invocation.executing", + Payload: mustJSON(map[string]any{ + "upstream": upstream.Name, + "request_id": inv.RequestID, + "input": json.RawMessage(inv.Input), + "arguments": json.RawMessage(inv.Input), + "task": true, + }), + CreatedAt: time.Now().UTC(), + }) + _, _ = s.finishExecution(context.Background(), inv, upstream, req) +} + +func isTerminalStatus(status Status) bool { + switch status { + case StatusSucceeded, StatusFailed, StatusDenied, StatusCancelled, StatusExpired: + return true + default: + return false + } +} + func mustJSON(v any) []byte { b, _ := json.Marshal(v) return b diff --git 
a/internal/kv/kv.go b/internal/kv/kv.go new file mode 100644 index 0000000..4c136a6 --- /dev/null +++ b/internal/kv/kv.go @@ -0,0 +1,191 @@ +package kv + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "sync" + "time" + + "github.com/redis/go-redis/v9" +) + +type Store interface { + Get(ctx context.Context, key string, dest any) (bool, error) + Set(ctx context.Context, key string, value any, ttl time.Duration) error + Update(ctx context.Context, key string, ttl time.Duration, update func([]byte) ([]byte, error)) (bool, error) + Delete(ctx context.Context, key string) error +} + +type MemoryStore struct { + mu sync.RWMutex + items map[string]memoryItem +} + +type RedisStore struct { + client *redis.Client +} + +type memoryItem struct { + value []byte + expiresAt time.Time +} + +func NewMemoryStore() *MemoryStore { + return &MemoryStore{items: make(map[string]memoryItem)} +} + +func NewStore(rawURL string) (Store, error) { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" || strings.EqualFold(rawURL, "memory://") || strings.EqualFold(rawURL, "memory") { + return NewMemoryStore(), nil + } + parsed, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + switch parsed.Scheme { + case "redis", "rediss": + opts, err := redis.ParseURL(rawURL) + if err != nil { + return nil, err + } + return NewRedisStore(opts), nil + default: + return nil, fmt.Errorf("unsupported kv url scheme %q", parsed.Scheme) + } +} + +func NewRedisStore(opts *redis.Options) *RedisStore { + return &RedisStore{client: redis.NewClient(opts)} +} + +func (s *MemoryStore) Get(_ context.Context, key string, dest any) (bool, error) { + s.mu.RLock() + item, ok := s.items[key] + s.mu.RUnlock() + if !ok || item.expired(time.Now()) { + if ok { + _ = s.Delete(context.Background(), key) + } + return false, nil + } + if err := json.Unmarshal(item.value, dest); err != nil { + return false, err + } + return true, nil +} + +func (s *MemoryStore) Set(_ 
context.Context, key string, value any, ttl time.Duration) error { + payload, err := json.Marshal(value) + if err != nil { + return err + } + s.mu.Lock() + s.items[key] = memoryItem{value: payload, expiresAt: expiresAt(ttl)} + s.mu.Unlock() + return nil +} + +func (s *MemoryStore) Update(_ context.Context, key string, ttl time.Duration, update func([]byte) ([]byte, error)) (bool, error) { + now := time.Now() + s.mu.Lock() + defer s.mu.Unlock() + item, ok := s.items[key] + if !ok || item.expired(now) { + if ok { + delete(s.items, key) + } + return false, nil + } + payload, err := update(append([]byte(nil), item.value...)) + if err != nil { + return false, err + } + item.value = payload + if ttl > 0 { + item.expiresAt = expiresAt(ttl) + } + s.items[key] = item + return true, nil +} + +func (s *MemoryStore) Delete(_ context.Context, key string) error { + s.mu.Lock() + delete(s.items, key) + s.mu.Unlock() + return nil +} + +func (i memoryItem) expired(now time.Time) bool { + return !i.expiresAt.IsZero() && now.After(i.expiresAt) +} + +func expiresAt(ttl time.Duration) time.Time { + if ttl <= 0 { + return time.Time{} + } + return time.Now().UTC().Add(ttl) +} + +func (s *RedisStore) Get(ctx context.Context, key string, dest any) (bool, error) { + payload, err := s.client.Get(ctx, key).Bytes() + if errors.Is(err, redis.Nil) { + return false, nil + } + if err != nil { + return false, err + } + if err := json.Unmarshal(payload, dest); err != nil { + return false, err + } + return true, nil +} + +func (s *RedisStore) Set(ctx context.Context, key string, value any, ttl time.Duration) error { + payload, err := json.Marshal(value) + if err != nil { + return err + } + return s.client.Set(ctx, key, payload, ttl).Err() +} + +func (s *RedisStore) Update(ctx context.Context, key string, ttl time.Duration, update func([]byte) ([]byte, error)) (bool, error) { + for { + found := true + err := s.client.Watch(ctx, func(tx *redis.Tx) error { + payload, err := tx.Get(ctx, key).Bytes() + if 
errors.Is(err, redis.Nil) { + found = false + return nil + } + if err != nil { + return err + } + updated, err := update(payload) + if err != nil { + return err + } + _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + expiration := ttl + if expiration <= 0 { + expiration = redis.KeepTTL + } + pipe.Set(ctx, key, updated, expiration) + return nil + }) + return err + }, key) + if errors.Is(err, redis.TxFailedErr) { + continue + } + return found, err + } +} + +func (s *RedisStore) Delete(ctx context.Context, key string) error { + return s.client.Del(ctx, key).Err() +} diff --git a/internal/kv/kv_test.go b/internal/kv/kv_test.go new file mode 100644 index 0000000..37f639a --- /dev/null +++ b/internal/kv/kv_test.go @@ -0,0 +1,86 @@ +package kv + +import ( + "context" + "encoding/json" + "testing" + "time" +) + +func TestMemoryStoreGetSetUpdateDelete(t *testing.T) { + store := NewMemoryStore() + ctx := context.Background() + type value struct { + Name string `json:"name"` + Seen bool `json:"seen"` + } + if err := store.Set(ctx, "k", value{Name: "session"}, time.Minute); err != nil { + t.Fatal(err) + } + var got value + found, err := store.Get(ctx, "k", &got) + if err != nil { + t.Fatal(err) + } + if !found || got.Name != "session" { + t.Fatalf("unexpected get result found=%t got=%+v", found, got) + } + updated, err := store.Update(ctx, "k", time.Minute, func(raw []byte) ([]byte, error) { + var v value + if err := json.Unmarshal(raw, &v); err != nil { + return nil, err + } + v.Seen = true + return json.Marshal(v) + }) + if err != nil { + t.Fatal(err) + } + if !updated { + t.Fatal("expected update to find key") + } + found, err = store.Get(ctx, "k", &got) + if err != nil { + t.Fatal(err) + } + if !found || !got.Seen { + t.Fatalf("expected updated value, found=%t got=%+v", found, got) + } + if err := store.Delete(ctx, "k"); err != nil { + t.Fatal(err) + } + found, err = store.Get(ctx, "k", &got) + if err != nil { + t.Fatal(err) + } + if found { + 
t.Fatal("expected deleted key to be absent") + } +} + +func TestMemoryStoreExpiresKeys(t *testing.T) { + store := NewMemoryStore() + ctx := context.Background() + if err := store.Set(ctx, "k", map[string]string{"v": "1"}, time.Nanosecond); err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond) + var got map[string]string + found, err := store.Get(ctx, "k", &got) + if err != nil { + t.Fatal(err) + } + if found { + t.Fatal("expected key to expire") + } +} + +func TestNewStoreCreatesRedisStore(t *testing.T) { + store, err := NewStore("redis://localhost:6379/0") + if err != nil { + t.Fatal(err) + } + if _, ok := store.(*RedisStore); !ok { + t.Fatalf("expected RedisStore, got %T", store) + } +} diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 908542b..60455c4 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -532,8 +532,12 @@ func (c *Client) invokeHTTP(ctx context.Context, upstream Upstream, tool string, if err != nil { return InvokeResult{}, err } + resultBody, err := decodeRPCPayload(result) + if err != nil { + return InvokeResult{}, err + } var rpcResp rpcResponse - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + if err := json.Unmarshal(resultBody, &rpcResp); err != nil { return InvokeResult{}, err } if len(rpcResp.Error) > 0 && string(rpcResp.Error) != "null" { @@ -575,8 +579,12 @@ func (c *Client) listToolsHTTP(ctx context.Context, upstream Upstream) ([]Tool, if err != nil { return nil, err } + resultBody, err := decodeRPCPayload(result) + if err != nil { + return nil, err + } var rpcResp rpcResponse - if err := json.Unmarshal(result.Body, &rpcResp); err != nil { + if err := json.Unmarshal(resultBody, &rpcResp); err != nil { return nil, err } if len(rpcResp.Error) > 0 && string(rpcResp.Error) != "null" { @@ -665,11 +673,12 @@ func (c *Client) doHTTPEnvelope(ctx context.Context, upstream Upstream, body []b return ForwardResult{StatusCode: resp.StatusCode, Body: bodyBytes, ContentType: contentType, 
ProtocolVersion: resp.Header.Get("MCP-Protocol-Version"), SessionExpired: true, SessionID: sessionID}, nil } if strings.Contains(contentType, "text/event-stream") { - data, sseErr := extractFirstSSEData(resp.Body) - if sseErr != nil { - return ForwardResult{}, sseErr + respBody := new(bytes.Buffer) + _, err = respBody.ReadFrom(resp.Body) + if err != nil { + return ForwardResult{}, err } - return ForwardResult{StatusCode: resp.StatusCode, Body: data, ContentType: "application/json", ProtocolVersion: resp.Header.Get("MCP-Protocol-Version"), SessionID: sessionID}, nil + return ForwardResult{StatusCode: resp.StatusCode, Body: respBody.Bytes(), ContentType: contentType, ProtocolVersion: resp.Header.Get("MCP-Protocol-Version"), SessionID: sessionID}, nil } respBody := new(bytes.Buffer) _, err = respBody.ReadFrom(resp.Body) @@ -858,6 +867,13 @@ func (c *Client) initializeHTTPSession(ctx context.Context, upstream Upstream, p return hadSession, nil } +func decodeRPCPayload(result ForwardResult) ([]byte, error) { + if strings.Contains(strings.ToLower(result.ContentType), "text/event-stream") { + return extractFirstSSEData(bytes.NewReader(result.Body)) + } + return result.Body, nil +} + func (c *Client) invokeStdio(ctx context.Context, upstream Upstream, tool string, input map[string]any) (InvokeResult, error) { if upstream.Command == "" { return InvokeResult{}, fmt.Errorf("stdio upstream %q missing command", upstream.Name)