From f711c15a2be1da1e13dcc1925339ee528f19f34a Mon Sep 17 00:00:00 2001 From: "dorey-agent[bot]" <3504508+dorey-agent[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 08:35:01 +0000 Subject: [PATCH 1/8] feat: add gateway route handler registration for plugins - Add routes contribution type to plugin schema (contributes.gateway.routes) - Routes are auto-namespaced under /plugins/{plugin-name}/ to prevent conflicts - Generator validates route path conflicts at generate time - Add gateway public_url config field for OAuth callbacks - Extend gateway SDK with RegisterRoute/MatchRoute + middleware Abort API - Add toJSON template function to middleware/route rendering mcp-oauth plugin updates: - Remove redundant token_file option (derive from token_dir + provider name) - Add OAuth callback handler at /plugins/mcp-oauth/callback - Middleware returns 401 + authorize_url when no token exists - Token files now per-provider: {token_dir}/{provider}.json --- core/plugins/mcp-oauth/README.md | 60 ++++-- core/plugins/mcp-oauth/handlers/callback.go | 223 ++++++++++++++++++++ core/plugins/mcp-oauth/middlewares/oauth.go | 85 ++++++-- core/plugins/mcp-oauth/plugin.yaml | 10 +- core/sdk/gateway/middleware.go | 56 +++++ core/sdk/gateway/middleware_test.go | 77 ++++++- docs/configuration.md | 1 + docs/guides/creating-plugins.md | 40 ++++ docs/plugins.md | 18 +- internal/config/config.go | 3 +- internal/generate/v1/gateway_config.go | 15 +- internal/generate/v1/gateway_config_test.go | 12 ++ internal/generate/v1/generator.go | 14 ++ internal/generate/v1/middleware_copy.go | 13 +- internal/generate/v1/routes.go | 152 +++++++++++++ internal/generate/v1/routes_test.go | 178 ++++++++++++++++ internal/plugin/merge.go | 1 + internal/plugin/types.go | 8 + 18 files changed, 920 insertions(+), 46 deletions(-) create mode 100644 core/plugins/mcp-oauth/handlers/callback.go create mode 100644 internal/generate/v1/routes.go create mode 100644 internal/generate/v1/routes_test.go diff --git a/core/plugins/mcp-oauth/README.md b/core/plugins/mcp-oauth/README.md index a12e07b..ae7b3aa 100644 --- a/core/plugins/mcp-oauth/README.md +++ b/core/plugins/mcp-oauth/README.md @@ -1,45 +1,69 @@ # mcp-oauth -Provides OAuth token storage for MCP (Model Context Protocol) providers via a shared volume between the gateway and agent. +Provides full OAuth lifecycle for MCP (Model Context Protocol) providers: automatic token injection, refresh, and browser-based authorization via gateway callback. ## How It Works -Declares a named volume (`oauth-tokens`) mounted into both the gateway and agent containers at a configurable path. MCP providers that require OAuth can store and read token files from this shared location. - -> **Note:** This plugin currently only handles the volume declaration. You must manually declare gateway services for each MCP provider endpoint in your `agent.yaml`. Full automation (dynamic service entries from provider URLs) is planned. +1. **Middleware** intercepts requests to configured domains. If a valid token exists, injects `Authorization: Bearer `. If no token exists, returns 401 with an `authorize_url` for the user to click. +2. **Callback handler** at `/plugins/mcp-oauth/callback` receives the OAuth authorization code, exchanges it for tokens, and writes the token file to the shared volume. +3. **Shared volume** (`oauth-tokens`) is mounted into both gateway and agent containers so the MCP client can read tokens written by the gateway. ## Usage ```yaml # agent.yaml +gateway: + public_url: "https://gateway.myagent.example.com" + services: + - url: https://mcp.notion.com + installations: - plugin: "@builtin/mcp-oauth" options: providers: notion: mcp_url: https://mcp.notion.com/mcp + authorize_endpoint: https://api.notion.com/v1/oauth/authorize + token_endpoint: https://api.notion.com/v1/oauth/token + client_id: "your-client-id" + client_secret: "your-client-secret" + scopes: "read_content" token_dir: "/data/oauth-tokens" - -gateway: - services: - - url: https://mcp.notion.com - headers: - Authorization: Bearer ${NOTION_TOKEN} -``` - -```bash -# .env -NOTION_TOKEN=ntn_xxxx ``` ## Options | Option | Type | Required | Default | Description | |--------|------|----------|---------|-------------| -| `providers` | object | yes | — | Map of provider name to config. Each provider needs at least `mcp_url`. | +| `providers` | object | yes | — | Map of provider name to OAuth config | | `token_dir` | string | no | `/data/oauth-tokens` | Directory for OAuth token files | +### Provider Config + +Each provider entry supports: + +| Field | Required | Description | +|-------|----------|-------------| +| `mcp_url` | yes | MCP server endpoint | +| `authorize_endpoint` | yes | OAuth authorize URL | +| `token_endpoint` | yes | OAuth token exchange URL | +| `client_id` | yes | OAuth client ID | +| `client_secret` | no | OAuth client secret | +| `scopes` | no | Space-separated scopes | + ## What It Contributes -- **Gateway:** Shared volume `oauth-tokens` mounted at `token_dir` -- **Agent:** Same volume accessible for MCP client token reads +- **Gateway middleware:** Token injection + 401 with authorize URL when unauthenticated +- **Gateway route:** `/plugins/mcp-oauth/callback` — OAuth code exchange handler +- **Gateway volume:** Shared `oauth-tokens` volume at `token_dir` + +## OAuth Flow + +``` +1. Agent MCP client → request to notion domain +2. Gateway middleware: no token file → returns 401 + authorize_url +3. User clicks authorize_url → Notion login page +4. Notion redirects → https://gateway.example.com/plugins/mcp-oauth/callback?code=X&state=notion +5. Gateway callback handler: exchanges code → writes /data/oauth-tokens/notion.json +6. Next request → middleware reads token → injects Bearer header → proxied to Notion +``` diff --git a/core/plugins/mcp-oauth/handlers/callback.go b/core/plugins/mcp-oauth/handlers/callback.go new file mode 100644 index 0000000..6b886f8 --- /dev/null +++ b/core/plugins/mcp-oauth/handlers/callback.go @@ -0,0 +1,223 @@ +package custom + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/donbader/agent-sandbox/core/sdk/gateway" +) + +// oauthProviderConfig holds OAuth provider settings baked in at generate time. +type oauthProviderConfig struct { + TokenEndpoint string + ClientID string + ClientSecret string + MCP_URL string +} + +var oauthCallbackProviders = map[string]oauthProviderConfig{} +var oauthCallbackTokenDir string + +func init() { + oauthCallbackTokenDir = "{{ .options.token_dir }}" + providersJSON := `{{ toJSON .options.providers }}` + var providers map[string]map[string]any + if err := json.Unmarshal([]byte(providersJSON), &providers); err == nil { + for name, cfg := range providers { + p := oauthProviderConfig{} + if v, ok := cfg["token_endpoint"].(string); ok { + p.TokenEndpoint = v + } + if v, ok := cfg["client_id"].(string); ok { + p.ClientID = v + } + if v, ok := cfg["client_secret"].(string); ok { + p.ClientSecret = v + } + if v, ok := cfg["mcp_url"].(string); ok { + p.MCP_URL = v + } + oauthCallbackProviders[name] = p + } + } + + gateway.RegisterRoute(gateway.RouteDef{ + Path: "{{ .path }}", + Handler: handleOAuthCallback, + }) +} + +func handleOAuthCallback(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") // state = provider name + + if code == "" { + http.Error(w, "missing code parameter", http.StatusBadRequest) + return + } + if state == "" { + http.Error(w, "missing state parameter (provider name)", http.StatusBadRequest) + return + } + + provider, ok := oauthCallbackProviders[state] + if !ok { + http.Error(w, fmt.Sprintf("unknown provider: %s", state), http.StatusBadRequest) + return + } + + if provider.TokenEndpoint == "" { + http.Error(w, fmt.Sprintf("provider %s has no token_endpoint configured", state), http.StatusInternalServerError) + return + } + + // Exchange authorization code for token + redirectURI := r.URL.Query().Get("redirect_uri") + if redirectURI == "" { + // Reconstruct from request + scheme := "https" + if r.TLS == nil { + scheme = "http" + } + redirectURI = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.Path) + } + + token, err := exchangeCode(provider, code, redirectURI) + if err != nil { + http.Error(w, fmt.Sprintf("token exchange failed: %v", err), http.StatusInternalServerError) + return + } + + // Write token file + tokenFile := oauthCallbackTokenDir + "/" + state + ".json" + if err := writeCallbackToken(tokenFile, token, provider); err != nil { + http.Error(w, fmt.Sprintf("failed to save token: %v", err), http.StatusInternalServerError) + return + } + + // Register the new access token as a secret for log redaction + gateway.RegisterSecret(token.AccessToken) + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, ` +

Authorization successful

+

Provider %s has been connected. You can close this tab.

+`, state) +} + +type callbackTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` +} + +func exchangeCode(provider oauthProviderConfig, code, redirectURI string) (*callbackTokenResponse, error) { + u, err := url.Parse(provider.TokenEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid token_endpoint URL: %w", err) + } + if u.Scheme != "https" { + return nil, fmt.Errorf("token_endpoint must use https, got %q", u.Scheme) + } + + params := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "client_id": {provider.ClientID}, + "redirect_uri": {redirectURI}, + } + if provider.ClientSecret != "" { + params.Set("client_secret", provider.ClientSecret) + } + + client := &http.Client{ + Timeout: 30 * time.Second, + Transport: callbackSSRFSafeTransport(), + } + + resp, err := client.Post( + provider.TokenEndpoint, + "application/x-www-form-urlencoded", + strings.NewReader(params.Encode()), + ) + if err != nil { + return nil, fmt.Errorf("token request to %s: %w", provider.TokenEndpoint, err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("reading token response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body)) + } + + var tr callbackTokenResponse + if err := json.Unmarshal(body, &tr); err != nil { + return nil, fmt.Errorf("parsing token response: %w", err) + } + return &tr, nil +} + +func writeCallbackToken(path string, token *callbackTokenResponse, provider oauthProviderConfig) error { + expiresIn := token.ExpiresIn + if expiresIn == 0 { + expiresIn = 3600 + } + + stored := map[string]any{ + "access_token": token.AccessToken, + "expires_at": time.Now().Unix() + expiresIn, + "token_endpoint": provider.TokenEndpoint, + "client_id": provider.ClientID, + } + if token.RefreshToken != "" { + stored["refresh_token"] = token.RefreshToken + } + if provider.ClientSecret != "" { + stored["client_secret"] = provider.ClientSecret + } + + data, err := json.MarshalIndent(stored, "", " ") + if err != nil { + return err + } + + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0600); err != nil { + return err + } + return os.Rename(tmp, path) +} + +func callbackSSRFSafeTransport() *http.Transport { + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("invalid address %q: %w", addr, err) + } + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("DNS lookup failed for %q: %w", host, err) + } + for _, ip := range ips { + if ip.IP.IsLoopback() || ip.IP.IsPrivate() || ip.IP.IsLinkLocalUnicast() || ip.IP.IsLinkLocalMulticast() { + return nil, fmt.Errorf("refusing to connect to private IP %s (resolved from %s)", ip.IP, host) + } + } + dialer := &net.Dialer{Timeout: 10 * time.Second} + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) + }, + } +} diff --git a/core/plugins/mcp-oauth/middlewares/oauth.go b/core/plugins/mcp-oauth/middlewares/oauth.go index 6b15299..739d73b 100644 --- a/core/plugins/mcp-oauth/middlewares/oauth.go +++ b/core/plugins/mcp-oauth/middlewares/oauth.go @@ -27,8 +27,15 @@ type storedToken struct { ClientSecret *string `json:"client_secret"` } +type oauthProviderInfo struct { + AuthorizeEndpoint string + ClientID string + Scopes string +} + type oauthState struct { - tokenFile string + tokenDir string + providers map[string]oauthProviderInfo mu sync.Mutex cachedToken *storedToken cachedUntil time.Time @@ -36,11 +43,38 @@ type oauthState struct { } func init() { - tokenFile := "{{ .options.token_file }}" + tokenDir := "{{ .options.token_dir }}" domains := strings.Split("{{ .domainsList }}", ",") + providersJSON := `{{ toJSON .options.providers }}` + providers := make(map[string]oauthProviderInfo) + var rawProviders map[string]map[string]any + if err := json.Unmarshal([]byte(providersJSON), &rawProviders); err == nil { + for name, cfg := range rawProviders { + p := oauthProviderInfo{} + if v, ok := cfg["authorize_endpoint"].(string); ok { + p.AuthorizeEndpoint = v + } + if v, ok := cfg["client_id"].(string); ok { + p.ClientID = v + } + if v, ok := cfg["scopes"].(string); ok { + p.Scopes = v + } + providers[name] = p + } + } + + var defaultProvider string + for name := range providers { + defaultProvider = name + break + } + + tokenFile := tokenDir + "/" + defaultProvider + ".json" state := &oauthState{ - tokenFile: tokenFile, + tokenDir: tokenDir, + providers: providers, httpClient: &http.Client{ Timeout: 30 * time.Second, Transport: oauthSSRFSafeTransport(), @@ -58,10 +92,19 @@ func init() { Name: "oauth:" + domains[0], Domains: domains, Func: func(ctx *gateway.MiddlewareContext) error { - token, err := state.getValidToken() + token, err := state.getValidToken(tokenFile) if err != nil { if errors.Is(err, os.ErrNotExist) { - slog.Debug("oauth: token file not found", "file", state.tokenFile) + if provider, ok := state.providers[defaultProvider]; ok && provider.AuthorizeEndpoint != "" { + authorizeURL := buildAuthorizeURL(provider, defaultProvider) + ctx.SetAbortHeader("X-OAuth-Authorize-URL", authorizeURL) + ctx.SetAbortHeader("Content-Type", "application/json") + ctx.Abort(http.StatusUnauthorized, fmt.Sprintf( + `{"error":"oauth_required","provider":%q,"authorize_url":%q}`, + defaultProvider, authorizeURL)) + return nil + } + slog.Debug("oauth: token file not found", "file", tokenFile) } else { slog.Error("oauth: failed to get token", "error", err) } @@ -73,6 +116,18 @@ func init() { }) } +func buildAuthorizeURL(provider oauthProviderInfo, providerName string) string { + params := url.Values{ + "client_id": {provider.ClientID}, + "response_type": {"code"}, + "state": {providerName}, + } + if provider.Scopes != "" { + params.Set("scope", provider.Scopes) + } + return provider.AuthorizeEndpoint + "?" + params.Encode() +} + func oauthSecrets(tokenFile string) []string { data, err := os.ReadFile(tokenFile) if err != nil { @@ -88,13 +143,13 @@ func oauthSecrets(tokenFile string) []string { return nil } -func (s *oauthState) getValidToken() (string, error) { +func (s *oauthState) getValidToken(tokenFile string) (string, error) { s.mu.Lock() defer s.mu.Unlock() if s.cachedToken != nil && time.Now().Before(s.cachedUntil) { return s.cachedToken.AccessToken, nil } - stored, err := s.readTokenFile() + stored, err := s.readTokenFile(tokenFile) if err != nil { return "", err } @@ -105,7 +160,7 @@ func (s *oauthState) getValidToken() (string, error) { return "", fmt.Errorf("token refresh failed: %w", err) } stored = refreshed - if err := s.writeTokenFile(stored); err != nil { + if err := s.writeTokenFile(tokenFile, stored); err != nil { slog.Error("oauth: failed to write refreshed token", "error", err) } } @@ -119,28 +174,28 @@ func (s *oauthState) getValidToken() (string, error) { return stored.AccessToken, nil } -func (s *oauthState) readTokenFile() (*storedToken, error) { - data, err := os.ReadFile(s.tokenFile) +func (s *oauthState) readTokenFile(tokenFile string) (*storedToken, error) { + data, err := os.ReadFile(tokenFile) if err != nil { - return nil, fmt.Errorf("reading token file %s: %w", s.tokenFile, err) + return nil, fmt.Errorf("reading token file %s: %w", tokenFile, err) } var token storedToken if err := json.Unmarshal(data, &token); err != nil { - return nil, fmt.Errorf("parsing token file %s: %w", s.tokenFile, err) + return nil, fmt.Errorf("parsing token file %s: %w", tokenFile, err) } return &token, nil } -func (s *oauthState) writeTokenFile(token *storedToken) error { +func (s *oauthState) writeTokenFile(tokenFile string, token *storedToken) error { data, err := json.MarshalIndent(token, "", " ") if err != nil { return err } - tmp := s.tokenFile + ".tmp" + tmp := tokenFile + ".tmp" if err := os.WriteFile(tmp, data, 0600); err != nil { return err } - return os.Rename(tmp, s.tokenFile) + return os.Rename(tmp, tokenFile) } type oauthTokenResponse struct { diff --git a/core/plugins/mcp-oauth/plugin.yaml b/core/plugins/mcp-oauth/plugin.yaml index a7f8970..224f985 100644 --- a/core/plugins/mcp-oauth/plugin.yaml +++ b/core/plugins/mcp-oauth/plugin.yaml @@ -3,17 +3,12 @@ options: providers: type: object required: true - description: "Map of provider name to MCP config" + description: "Map of provider name to MCP config (each needs mcp_url, client_id, client_secret)" token_dir: type: string required: false default: "/data/oauth-tokens" description: "Directory for OAuth token files" - token_file: - type: string - required: false - default: "/data/oauth-tokens/token.json" - description: "Path to OAuth token JSON file" contributes: gateway: @@ -21,3 +16,6 @@ contributes: - "oauth-tokens:{{ .plugin.options.token_dir }}" middlewares: - custom: "./middlewares/oauth.go" + routes: + - path: "/callback" + handler: "./handlers/callback.go" diff --git a/core/sdk/gateway/middleware.go b/core/sdk/gateway/middleware.go index 3fe59bc..cadb75f 100644 --- a/core/sdk/gateway/middleware.go +++ b/core/sdk/gateway/middleware.go @@ -3,12 +3,33 @@ package gateway import ( "net" "net/http" + "strings" ) // MiddlewareContext provides request access and environment resolution for custom middleware. type MiddlewareContext struct { Request *http.Request Env func(string) string + + // Abort fields: if AbortStatus is set (non-zero), the gateway returns this + // response instead of proxying the request to the upstream. + AbortStatus int + AbortHeaders http.Header + AbortBody string +} + +// Abort sets the context to return an HTTP response instead of proxying. +func (c *MiddlewareContext) Abort(status int, body string) { + c.AbortStatus = status + c.AbortBody = body +} + +// SetAbortHeader sets a header on the abort response. +func (c *MiddlewareContext) SetAbortHeader(key, value string) { + if c.AbortHeaders == nil { + c.AbortHeaders = make(http.Header) + } + c.AbortHeaders.Set(key, value) } // MiddlewareFunc is the signature for custom gateway middleware. @@ -80,4 +101,39 @@ func Secrets() []string { func ResetForTesting() { registry = nil secrets = nil + routeRegistry = nil +} + +// RouteHandlerFunc is the signature for custom gateway route handlers. +// Unlike middleware (which intercepts proxy requests), route handlers serve +// direct HTTP requests to the gateway on registered paths. +type RouteHandlerFunc func(w http.ResponseWriter, r *http.Request) + +// RouteDef defines a route handler with a path pattern. +type RouteDef struct { + // Path is the full namespaced path (e.g. /plugins/mcp-oauth/notion/callback). + Path string + Handler RouteHandlerFunc +} + +var routeRegistry []RouteDef + +// RegisterRoute registers a direct HTTP route handler on the gateway. +func RegisterRoute(def RouteDef) { + routeRegistry = append(routeRegistry, def) +} + +// Routes returns all registered route definitions. +func Routes() []RouteDef { + return routeRegistry +} + +// MatchRoute returns the handler for the given request path, or nil if no route matches. +func MatchRoute(path string) RouteHandlerFunc { + for _, route := range routeRegistry { + if path == route.Path || strings.HasPrefix(path, route.Path+"/") { + return route.Handler + } + } + return nil } diff --git a/core/sdk/gateway/middleware_test.go b/core/sdk/gateway/middleware_test.go index 36368a8..63db982 100644 --- a/core/sdk/gateway/middleware_test.go +++ b/core/sdk/gateway/middleware_test.go @@ -1,6 +1,9 @@ package gateway -import "testing" +import ( + "net/http" + "testing" +) func TestRegisterSecret(t *testing.T) { // Reset state @@ -21,3 +24,75 @@ func TestRegisterSecret(t *testing.T) { t.Errorf("expected 'another-secret', got %q", got[1]) } } + +func TestRegisterRoute(t *testing.T) { + ResetForTesting() + + handler := func(w http.ResponseWriter, r *http.Request) {} + RegisterRoute(RouteDef{Path: "/plugins/mcp-oauth/callback", Handler: handler}) + RegisterRoute(RouteDef{Path: "/plugins/other/webhook", Handler: handler}) + + routes := Routes() + if len(routes) != 2 { + t.Fatalf("expected 2 routes, got %d", len(routes)) + } + if routes[0].Path != "/plugins/mcp-oauth/callback" { + t.Errorf("expected '/plugins/mcp-oauth/callback', got %q", routes[0].Path) + } + if routes[1].Path != "/plugins/other/webhook" { + t.Errorf("expected '/plugins/other/webhook', got %q", routes[1].Path) + } +} + +func TestMatchRoute(t *testing.T) { + ResetForTesting() + + RegisterRoute(RouteDef{ + Path: "/plugins/mcp-oauth/callback", + Handler: func(w http.ResponseWriter, r *http.Request) {}, + }) + + // Exact match + h := MatchRoute("/plugins/mcp-oauth/callback") + if h == nil { + t.Fatal("expected handler for exact path match") + } + + // Prefix match (sub-path) + h = MatchRoute("/plugins/mcp-oauth/callback/extra") + if h == nil { + t.Fatal("expected handler for prefix path match") + } + + // No match + h = MatchRoute("/plugins/other/callback") + if h != nil { + t.Fatal("expected nil for non-matching path") + } + + // No match (partial prefix without /) + h = MatchRoute("/plugins/mcp-oauth/callbackextra") + if h != nil { + t.Fatal("expected nil for partial prefix without /") + } +} + +func TestMiddlewareContext_Abort(t *testing.T) { + ctx := &MiddlewareContext{ + Request: &http.Request{}, + Env: func(s string) string { return "" }, + } + + ctx.SetAbortHeader("X-Test", "value") + ctx.Abort(401, `{"error":"unauthorized"}`) + + if ctx.AbortStatus != 401 { + t.Errorf("expected status 401, got %d", ctx.AbortStatus) + } + if ctx.AbortBody != `{"error":"unauthorized"}` { + t.Errorf("unexpected body: %s", ctx.AbortBody) + } + if ctx.AbortHeaders.Get("X-Test") != "value" { + t.Errorf("expected X-Test header, got %q", ctx.AbortHeaders.Get("X-Test")) + } +} diff --git a/docs/configuration.md b/docs/configuration.md index dd5945f..9a58d07 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -50,6 +50,7 @@ runtime: - "./local:/home/agent/local" gateway: + public_url: string # optional — public URL for OAuth callbacks and webhook receivers services: # optional — external services proxied through the gateway - url: https://api.example.com network: string # optional — compose network to attach diff --git a/docs/guides/creating-plugins.md b/docs/guides/creating-plugins.md index 8389c6f..0efac82 100644 --- a/docs/guides/creating-plugins.md +++ b/docs/guides/creating-plugins.md @@ -100,10 +100,50 @@ contributes: | `runtime.volumes` | Volume mount specs for docker-compose | | `gateway.services` | Services the gateway intercepts. Each entry has a `url` and a list of `middlewares` | | `gateway.services[].middlewares[].custom` | Path to a Go middleware file, relative to the plugin directory | +| `gateway.routes` | HTTP route handlers registered on the gateway (see below) | | `sidecar.services` | Additional Docker containers that run alongside the agent (see below) | All `contributes` fields support Go `text/template` syntax. See **Template Context** below for available variables and functions. +### `contributes.gateway.routes` + +Registers HTTP route handlers directly on the gateway. Routes are automatically namespaced under `/plugins/{plugin-name}/` to prevent conflicts between plugins. + +```yaml +contributes: + gateway: + routes: + - path: "/callback" + handler: "./handlers/callback.go" +``` + +| Field | Description | +|-------|-------------| +| `path` | Relative path (prefixed automatically with `/plugins/{plugin-name}`) | +| `handler` | Path to a Go handler file, relative to the plugin directory | + +The handler file is a Go template (same as middleware) with access to `{{ .path }}` (the full namespaced path) and `{{ .options }}`. It must self-register using `gateway.RegisterRoute()` in an `init()` function: + +```go +package custom + +import ( + "net/http" + "github.com/donbader/agent-sandbox/core/sdk/gateway" +) + +func init() { + gateway.RegisterRoute(gateway.RouteDef{ + Path: "{{ .path }}", + Handler: myHandler, + }) +} + +func myHandler(w http.ResponseWriter, r *http.Request) { + // Handle the request +} +``` + ### `contributes.sidecar.services` Defines additional Docker containers that run alongside the agent. Each sidecar is a separate compose service. Sidecars automatically get `depends_on: { agent: { condition: service_healthy } }` so they start only after the agent is healthy. diff --git a/docs/plugins.md b/docs/plugins.md index 4cbc010..7c35ba9 100644 --- a/docs/plugins.md +++ b/docs/plugins.md @@ -93,24 +93,36 @@ Key-only auth. No passwords. Connect with `ssh -p 2222 agent@localhost`. ### @builtin/mcp-oauth -OAuth token storage for MCP (Model Context Protocol) providers via a shared volume. +Full OAuth lifecycle for MCP providers: token injection, automatic refresh, and browser-based authorization via gateway callback. ```yaml +gateway: + public_url: "https://gateway.myagent.example.com" + services: + - url: https://mcp.notion.com + installations: - plugin: "@builtin/mcp-oauth" options: providers: notion: mcp_url: https://mcp.notion.com/mcp + authorize_endpoint: https://api.notion.com/v1/oauth/authorize + token_endpoint: https://api.notion.com/v1/oauth/token + client_id: "your-client-id" + client_secret: "your-client-secret" + scopes: "read_content" token_dir: "/data/oauth-tokens" ``` | Option | Type | Required | Default | Description | |--------|------|----------|---------|-------------| -| `providers` | object | yes | — | Map of provider name to MCP config (`mcp_url` required per provider) | +| `providers` | object | yes | — | Map of provider name to OAuth config | | `token_dir` | string | no | `/data/oauth-tokens` | Directory for OAuth token files | -You must also declare the provider endpoints as gateway services in your `agent.yaml`. +**Flow:** When no token exists, the middleware returns 401 with an `authorize_url`. The user visits the URL, authorizes, and the OAuth provider redirects to `/plugins/mcp-oauth/callback` on the gateway. The callback handler exchanges the code for tokens and writes them to the shared volume. Subsequent requests are automatically authenticated. + +**Contributes:** gateway middleware (token injection + 401), gateway route (`/plugins/mcp-oauth/callback`), shared volume. --- diff --git a/internal/config/config.go b/internal/config/config.go index b87c1fc..013204d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -67,7 +67,8 @@ type RuntimeConfig struct { // GatewayConfig holds gateway proxy configuration. type GatewayConfig struct { - Services []GatewayServiceEntry `yaml:"services" json:"services,omitempty" jsonschema:"title=services,description=External services proxied through the gateway"` + PublicURL string `yaml:"public_url" json:"public_url,omitempty" jsonschema:"title=public_url,description=Public URL of the gateway (used for OAuth callbacks and webhook receivers)"` + Services []GatewayServiceEntry `yaml:"services" json:"services,omitempty" jsonschema:"title=services,description=External services proxied through the gateway"` } // GatewayServiceEntry represents an allowed upstream service. diff --git a/internal/generate/v1/gateway_config.go b/internal/generate/v1/gateway_config.go index 854583b..9789a1c 100644 --- a/internal/generate/v1/gateway_config.go +++ b/internal/generate/v1/gateway_config.go @@ -19,6 +19,8 @@ type GatewayConfigOutput struct { Services []GatewayServiceOutput Middlewares []MiddlewareRef // custom .go files to copy with domain scope AuthHeaders []AuthHeaderEntry // auth-header entries to generate as .go files + Routes []RouteRef // route handler .go files with namespaced paths + PublicURL string // gateway public URL for callbacks } // MiddlewareRef associates a custom middleware file with its target domains. @@ -27,6 +29,13 @@ type MiddlewareRef struct { Domains []string // domains this middleware applies to } +// RouteRef associates a route handler file with its namespaced path. +type RouteRef struct { + Path string // namespaced URL path (e.g. /plugins/mcp-oauth/callback) + Handler string // path to handler .go file + PluginName string // plugin that contributed this route +} + // AuthHeaderEntry describes an auth-header middleware to generate at build time. type AuthHeaderEntry struct { Domain string @@ -48,11 +57,14 @@ type gatewayRuntimeConfig struct { DNSListen string `yaml:"dns_listen"` MITMDomains []string `yaml:"mitm_domains"` HealthAddr string `yaml:"health_addr,omitempty"` + PublicURL string `yaml:"public_url,omitempty"` } // BuildGatewayConfig merges user gateway config with plugin contributions. func BuildGatewayConfig(cfg *config.Config, contribs *plugin.Contributions) *GatewayConfigOutput { - out := &GatewayConfigOutput{} + out := &GatewayConfigOutput{ + PublicURL: cfg.Gateway.PublicURL, + } // User-declared services for _, svc := range cfg.Gateway.Services { @@ -120,6 +132,7 @@ func WriteGatewayRuntimeConfig(buildDir string, gwCfg *GatewayConfigOutput) erro rc := gatewayRuntimeConfig{ Listen: ":8443", DNSListen: ":53", + PublicURL: gwCfg.PublicURL, } for _, svc := range gwCfg.Services { diff --git a/internal/generate/v1/gateway_config_test.go b/internal/generate/v1/gateway_config_test.go index fb8bc60..0066177 100644 --- a/internal/generate/v1/gateway_config_test.go +++ b/internal/generate/v1/gateway_config_test.go @@ -72,3 +72,15 @@ func TestBuildGatewayConfig_NilContribs(t *testing.T) { gwCfg := BuildGatewayConfig(cfg, nil) require.Len(t, gwCfg.Services, 1) } + +func TestBuildGatewayConfig_PublicURL(t *testing.T) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + PublicURL: "https://gateway.example.com", + Services: []config.GatewayServiceEntry{{URL: "https://api.example.com"}}, + }, + } + + gwCfg := BuildGatewayConfig(cfg, nil) + assert.Equal(t, "https://gateway.example.com", gwCfg.PublicURL) +} diff --git a/internal/generate/v1/generator.go b/internal/generate/v1/generator.go index 824855e..662e501 100644 --- a/internal/generate/v1/generator.go +++ b/internal/generate/v1/generator.go @@ -271,6 +271,14 @@ func (g *Generator) generateAgent(cfg *config.Config, agentDir, buildDir string) } gwCfg := BuildGatewayConfig(cfg, merged) + + // Collect routes from each plugin with namespace prefixing + routeRefs, err := g.collectPluginRoutes(resolved, buildDir) + if err != nil { + return nil, err + } + gwCfg.Routes = routeRefs + if err := WriteGatewayRuntimeConfig(buildDir, gwCfg); err != nil { return nil, fmt.Errorf("write gateway runtime config: %w", err) } @@ -285,6 +293,12 @@ func (g *Generator) generateAgent(cfg *config.Config, agentDir, buildDir string) return nil, fmt.Errorf("generate auth-header middleware: %w", err) } } + if len(gwCfg.Routes) > 0 { + allOpts := collectAllOptions(cfg) + if err := CopyRouteHandlers(g.projectDir, buildDir, gwCfg.Routes, allOpts); err != nil { + return nil, fmt.Errorf("copy route handlers: %w", err) + } + } return &AgentResult{Config: cfg, Contribs: merged, BuildDir: buildDir}, nil } diff --git a/internal/generate/v1/middleware_copy.go b/internal/generate/v1/middleware_copy.go index 4980fda..cc1c047 100644 --- a/internal/generate/v1/middleware_copy.go +++ b/internal/generate/v1/middleware_copy.go @@ -3,6 +3,7 @@ package v1 import ( "bytes" "encoding/base64" + "encoding/json" "fmt" "os" "path/filepath" @@ -134,7 +135,17 @@ func renderMiddleware(name, content string, data map[string]any) (string, error) return content, nil } - tmpl, err := template.New(name).Parse(content) + funcMap := template.FuncMap{ + "toJSON": func(v any) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("toJSON: %w", err) + } + return string(b), nil + }, + } + + tmpl, err := template.New(name).Funcs(funcMap).Parse(content) if err != nil { return "", fmt.Errorf("parse template: %w", err) } diff --git a/internal/generate/v1/routes.go b/internal/generate/v1/routes.go new file mode 100644 index 0000000..58c10f5 --- /dev/null +++ b/internal/generate/v1/routes.go @@ -0,0 +1,152 @@ +package v1 + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "text/template" +) + +// collectPluginRoutes iterates over resolved plugins, namespaces their declared routes, +// checks for conflicts, and extracts/resolves handler file paths. +// Returns route refs ready for code generation. +func (g *Generator) collectPluginRoutes(resolved map[string]*resolvedPlugin, buildDir string) ([]RouteRef, error) { + var routes []RouteRef + seen := make(map[string]string) // full path → plugin name (for conflict detection) + + for ref, rp := range resolved { + for _, route := range rp.rendered.Gateway.Routes { + if route.Path == "" || route.Handler == "" { + return nil, fmt.Errorf("plugin %q: route entry must have both path and handler", ref) + } + + // Namespace: /plugins/{plugin-name}{declared-path} + namespacedPath := "/plugins/" + rp.def.Name + normalizePath(route.Path) + + // Conflict detection + if owner, exists := seen[namespacedPath]; exists { + return nil, fmt.Errorf("route conflict: path %q registered by both %q and %q", namespacedPath, owner, ref) + } + seen[namespacedPath] = ref + + // Resolve handler file path + handlerPath, err := g.resolveRouteHandler(rp.def.Name, route.Handler, rp.def.BaseDir, buildDir) + if err != nil { + return nil, fmt.Errorf("plugin %q route %q: %w", ref, route.Path, err) + } + + routes = append(routes, RouteRef{ + Path: namespacedPath, + Handler: handlerPath, + PluginName: rp.def.Name, + }) + } + } + + return routes, nil +} + +// resolveRouteHandler resolves a route handler file path. +// For local plugins (baseDir != ""), it's relative to the plugin directory. +// For bundled plugins, it's extracted from the bundled FS. +func (g *Generator) resolveRouteHandler(pluginName, handler, baseDir, buildDir string) (string, error) { + if baseDir != "" { + return filepath.Join(baseDir, handler), nil + } + + // Bundled plugin — extract handler from FS + if g.bundledFS == nil { + return "", fmt.Errorf("no bundled FS available to extract handler %q", handler) + } + return g.extractBundledMiddleware(pluginName, handler, buildDir) +} + +// CopyRouteHandlers copies route handler .go files into the gateway build context. +// Each handler is a Go template rendered with path and options, and self-registers +// via init() calling gateway.RegisterRoute() — same pattern as custom middleware. +func CopyRouteHandlers(projectDir, outDir string, routes []RouteRef, opts map[string]any) error { + if len(routes) == 0 { + return nil + } + + destDir := filepath.Join(outDir, "gateway-src", "core", "gateway", "middlewares", "custom") + if err := os.MkdirAll(destDir, 0755); err != nil { + return fmt.Errorf("create route handler dest dir: %w", err) + } + + // Resolve ${VAR} references in options to actual env values + resolved := resolveEnvVars(opts) + + for _, route := range routes { + var srcPath string + if filepath.IsAbs(route.Handler) { + srcPath = route.Handler + } else { + srcPath = filepath.Join(projectDir, route.Handler) + } + content, err := os.ReadFile(srcPath) + if err != nil { + return fmt.Errorf("read route handler %s: %w", route.Handler, err) + } + + // Template data includes options and the namespaced path + data := map[string]any{ + "options": resolved, + "path": route.Path, + } + + // Template-render the handler file + rendered, err := renderRouteHandler(srcPath, string(content), data) + if err != nil { + return fmt.Errorf("render route handler %s: %w", route.Handler, err) + } + + filename := fmt.Sprintf("route_%s_%s.go", sanitizeFilename(route.PluginName), sanitizeFilename(filepath.Base(route.Handler))) + destFile := filepath.Join(destDir, filename) + if err := os.WriteFile(destFile, []byte(rendered), 0644); err != nil { + return fmt.Errorf("write route handler %s: %w", destFile, err) + } + } + + return nil +} + +// renderRouteHandler executes Go templates in route handler source code. +func renderRouteHandler(name, content string, data map[string]any) (string, error) { + if !strings.Contains(content, "{{") { + return content, nil + } + + funcMap := template.FuncMap{ + "toJSON": func(v any) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("toJSON: %w", err) + } + return string(b), nil + }, + } + + tmpl, err := template.New(name).Funcs(funcMap).Parse(content) + if err != nil { + return "", fmt.Errorf("parse template: %w", err) + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return "", fmt.Errorf("execute template: %w", err) + } + return buf.String(), nil +} + +// normalizePath ensures a path starts with / and has no trailing slash. +func normalizePath(path string) string { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + return strings.TrimRight(path, "/") +} + diff --git a/internal/generate/v1/routes_test.go b/internal/generate/v1/routes_test.go new file mode 100644 index 0000000..8c9096b --- /dev/null +++ b/internal/generate/v1/routes_test.go @@ -0,0 +1,178 @@ +package v1 + +import ( + "os" + "path/filepath" + "testing" + "testing/fstest" + + "github.com/donbader/agent-sandbox/internal/plugin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCollectPluginRoutes_Namespacing(t *testing.T) { + tmpDir := t.TempDir() + handlerFile := filepath.Join(tmpDir, "handler.go") + require.NoError(t, os.WriteFile(handlerFile, []byte("package custom"), 0644)) + + g := &Generator{projectDir: tmpDir} + resolved := map[string]*resolvedPlugin{ + "@builtin/mcp-oauth": { + def: &plugin.PluginDef{Name: "mcp-oauth", BaseDir: tmpDir}, + rendered: &plugin.Contributions{ + Gateway: plugin.GatewayContrib{ + Routes: []plugin.RouteEntry{ + {Path: "/callback", Handler: "handler.go"}, + }, + }, + }, + }, + } + + routes, err := g.collectPluginRoutes(resolved, tmpDir) + require.NoError(t, err) + require.Len(t, routes, 1) + assert.Equal(t, "/plugins/mcp-oauth/callback", routes[0].Path) + assert.Equal(t, "mcp-oauth", routes[0].PluginName) +} + +func TestCollectPluginRoutes_ConflictDetection(t *testing.T) { + tmpDir := t.TempDir() + handlerFile := filepath.Join(tmpDir, "handler.go") + require.NoError(t, os.WriteFile(handlerFile, []byte("package custom"), 0644)) + + g := &Generator{projectDir: tmpDir} + // Two plugins with the same name would produce the same namespace path + // In practice this can't happen (plugin names are unique), but test the mechanism + resolved := map[string]*resolvedPlugin{ + "@builtin/my-plugin": { + def: &plugin.PluginDef{Name: "my-plugin", BaseDir: tmpDir}, + rendered: &plugin.Contributions{ + Gateway: plugin.GatewayContrib{ + Routes: []plugin.RouteEntry{ + {Path: "/webhook", Handler: "handler.go"}, + {Path: "/webhook", Handler: "handler.go"}, // duplicate + }, + }, + }, + }, + } + + _, err := g.collectPluginRoutes(resolved, tmpDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "route conflict") +} + +func TestCollectPluginRoutes_MissingPath(t *testing.T) { + tmpDir := t.TempDir() + g := &Generator{projectDir: tmpDir} + resolved := map[string]*resolvedPlugin{ + "@builtin/bad": { + def: &plugin.PluginDef{Name: "bad", BaseDir: tmpDir}, + rendered: &plugin.Contributions{ + Gateway: plugin.GatewayContrib{ + Routes: []plugin.RouteEntry{ + {Path: "", Handler: "handler.go"}, + }, + }, + }, + }, + } + + _, err := g.collectPluginRoutes(resolved, tmpDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "must have both path and handler") +} + +func TestCollectPluginRoutes_BundledFS(t *testing.T) { + tmpDir := t.TempDir() + buildDir := filepath.Join(tmpDir, ".build") + require.NoError(t, os.MkdirAll(buildDir, 0755)) + + bundledFS := fstest.MapFS{ + "my-plugin/handlers/callback.go": &fstest.MapFile{Data: []byte("package custom")}, + } + + g := &Generator{projectDir: tmpDir, bundledFS: bundledFS} + resolved := map[string]*resolvedPlugin{ + "@builtin/my-plugin": { + def: &plugin.PluginDef{Name: "my-plugin"}, // no BaseDir = bundled + rendered: &plugin.Contributions{ + Gateway: plugin.GatewayContrib{ + Routes: []plugin.RouteEntry{ + {Path: "/callback", Handler: "./handlers/callback.go"}, + }, + }, + }, + }, + } + + routes, err := g.collectPluginRoutes(resolved, buildDir) + require.NoError(t, err) + require.Len(t, routes, 1) + assert.Equal(t, "/plugins/my-plugin/callback", routes[0].Path) + + // Verify handler was extracted + _, err = os.Stat(routes[0].Handler) + assert.NoError(t, err) +} + +func TestCopyRouteHandlers(t *testing.T) { + tmpDir := t.TempDir() + outDir := filepath.Join(tmpDir, ".build") + require.NoError(t, os.MkdirAll(outDir, 0755)) + + // Create a handler file with a template variable + handlerContent := `package custom + +import "github.com/donbader/agent-sandbox/core/sdk/gateway" + +func init() { + gateway.RegisterRoute(gateway.RouteDef{ + Path: "{{ .path }}", + }) +} +` + handlerFile := filepath.Join(tmpDir, "handler.go") + require.NoError(t, os.WriteFile(handlerFile, []byte(handlerContent), 0644)) + + routes := []RouteRef{ + { + Path: "/plugins/mcp-oauth/callback", + Handler: handlerFile, + PluginName: "mcp-oauth", + }, + } + + err := CopyRouteHandlers(tmpDir, outDir, routes, map[string]any{}) + require.NoError(t, err) + + // Verify the rendered handler exists + destDir := filepath.Join(outDir, "gateway-src", "core", "gateway", "middlewares", "custom") + entries, err := os.ReadDir(destDir) + require.NoError(t, err) + require.Len(t, entries, 1) + + // Verify template was rendered + content, err := os.ReadFile(filepath.Join(destDir, entries[0].Name())) + require.NoError(t, err) + assert.Contains(t, string(content), `/plugins/mcp-oauth/callback`) + assert.NotContains(t, string(content), "{{ .path }}") +} + +func TestNormalizePath(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"/callback", "/callback"}, + {"callback", "/callback"}, + {"/callback/", "/callback"}, + {"/foo/bar/", "/foo/bar"}, + } + for _, tt := range tests { + got := normalizePath(tt.input) + assert.Equal(t, tt.want, got, "normalizePath(%q)", tt.input) + } +} diff --git a/internal/plugin/merge.go b/internal/plugin/merge.go index c5c4c01..7a9b574 100644 --- a/internal/plugin/merge.go +++ b/internal/plugin/merge.go @@ -16,6 +16,7 @@ func MergeContributions(contribs ...*Contributions) *Contributions { merged.Runtime.Volumes = append(merged.Runtime.Volumes, c.Runtime.Volumes...) merged.Gateway.Services = append(merged.Gateway.Services, c.Gateway.Services...) merged.Gateway.Volumes = append(merged.Gateway.Volumes, c.Gateway.Volumes...) + merged.Gateway.Routes = append(merged.Gateway.Routes, c.Gateway.Routes...) for name, svc := range c.Sidecar.Services { merged.Sidecar.Services[name] = svc } diff --git a/internal/plugin/types.go b/internal/plugin/types.go index da30622..ca6453c 100644 --- a/internal/plugin/types.go +++ b/internal/plugin/types.go @@ -42,6 +42,14 @@ type RuntimeContrib struct { type GatewayContrib struct { Services []GatewayService `yaml:"services"` Volumes []string `yaml:"volumes"` + Routes []RouteEntry `yaml:"routes"` +} + +// RouteEntry declares an HTTP route handler contributed by a plugin. +// The path is relative to the plugin's namespace (/plugins/{plugin-name}/...). +type RouteEntry struct { + Path string `yaml:"path"` // relative path (e.g. "/callback") + Handler string `yaml:"handler"` // path to handler .go file } type GatewayService struct { From 63840b9543c914529be2486b5f1668765f50b7df Mon Sep 17 00:00:00 2001 From: "dorey-agent[bot]" <3504508+dorey-agent[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 08:50:02 +0000 Subject: [PATCH 2/8] fix: address security review findings in mcp-oauth plugin - CSRF: add nonce-based state parameter (single-use, validated on callback) - Path traversal: state is validated against known provider map (no raw file path construction) - redirect_uri: derived from gateway public_url (baked at generate time), never from request headers - redirect_uri included in authorize URL for proper OAuth round-trip - Non-deterministic provider: use sorted keys for deterministic default selection - client_secret: no longer persisted to token files (stays in gateway binary only) - Error logging: JSON parse failures now logged instead of silently swallowed - Error messages: internal details (endpoint URLs, provider errors) no longer leaked to HTTP responses - HTML escaping: provider name escaped in success page - public_url passed to route handler templates via new parameter --- core/plugins/mcp-oauth/handlers/callback.go | 140 ++++++++++---------- core/plugins/mcp-oauth/middlewares/oauth.go | 24 +++- internal/generate/v1/generator.go | 2 +- internal/generate/v1/routes.go | 9 +- internal/generate/v1/routes_test.go | 2 +- 5 files changed, 97 insertions(+), 80 deletions(-) diff --git a/core/plugins/mcp-oauth/handlers/callback.go b/core/plugins/mcp-oauth/handlers/callback.go index 6b886f8..532adfc 100644 --- a/core/plugins/mcp-oauth/handlers/callback.go +++ b/core/plugins/mcp-oauth/handlers/callback.go @@ -2,37 +2,48 @@ package custom import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" + "html" "io" + "log/slog" "net" "net/http" "net/url" "os" "strings" + "sync" "time" "github.com/donbader/agent-sandbox/core/sdk/gateway" ) -// oauthProviderConfig holds OAuth provider settings baked in at generate time. -type oauthProviderConfig struct { +type oauthCallbackConfig struct { TokenEndpoint string ClientID string ClientSecret string - MCP_URL string } -var oauthCallbackProviders = map[string]oauthProviderConfig{} -var oauthCallbackTokenDir string +var ( + oauthCallbackProviders = map[string]oauthCallbackConfig{} + oauthCallbackTokenDir string + oauthCallbackPublicURL string + oauthNonces = map[string]string{} + oauthNoncesMu sync.Mutex +) func init() { oauthCallbackTokenDir = "{{ .options.token_dir }}" + oauthCallbackPublicURL = "{{ .public_url }}" providersJSON := `{{ toJSON .options.providers }}` var providers map[string]map[string]any - if err := json.Unmarshal([]byte(providersJSON), &providers); err == nil { + if err := json.Unmarshal([]byte(providersJSON), &providers); err != nil { + slog.Error("oauth-callback: failed to parse providers config", "error", err) + } else { for name, cfg := range providers { - p := oauthProviderConfig{} + p := oauthCallbackConfig{} if v, ok := cfg["token_endpoint"].(string); ok { p.TokenEndpoint = v } @@ -42,93 +53,100 @@ func init() { if v, ok := cfg["client_secret"].(string); ok { p.ClientSecret = v } - if v, ok := cfg["mcp_url"].(string); ok { - p.MCP_URL = v - } oauthCallbackProviders[name] = p } } - gateway.RegisterRoute(gateway.RouteDef{ Path: "{{ .path }}", Handler: handleOAuthCallback, }) } +// GenerateOAuthNonce creates a CSRF nonce for the given provider. +func GenerateOAuthNonce(provider string) string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return provider + } + state := hex.EncodeToString(b) + ":" + provider + oauthNoncesMu.Lock() + oauthNonces[state] = provider + oauthNoncesMu.Unlock() + return state +} + +// OAuthCallbackURL returns the full callback URL. +func OAuthCallbackURL() string { + return oauthCallbackPublicURL + "{{ .path }}" +} + func handleOAuthCallback(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") - state := r.URL.Query().Get("state") // state = provider name - + state := r.URL.Query().Get("state") if code == "" { http.Error(w, "missing code parameter", http.StatusBadRequest) return } if state == "" { - http.Error(w, "missing state parameter (provider name)", http.StatusBadRequest) + http.Error(w, "missing state parameter", http.StatusBadRequest) return } - - provider, ok := oauthCallbackProviders[state] + // Validate CSRF nonce + oauthNoncesMu.Lock() + providerName, valid := oauthNonces[state] + if valid { + delete(oauthNonces, state) + } + oauthNoncesMu.Unlock() + if !valid { + http.Error(w, "invalid or expired state", http.StatusForbidden) + return + } + provider, ok := oauthCallbackProviders[providerName] if !ok { - http.Error(w, fmt.Sprintf("unknown provider: %s", state), http.StatusBadRequest) + http.Error(w, "unknown provider", http.StatusBadRequest) return } - if provider.TokenEndpoint == "" { - http.Error(w, fmt.Sprintf("provider %s has no token_endpoint configured", state), http.StatusInternalServerError) + http.Error(w, "provider not configured", http.StatusInternalServerError) return } - - // Exchange authorization code for token - redirectURI := r.URL.Query().Get("redirect_uri") - if redirectURI == "" { - // Reconstruct from request - scheme := "https" - if r.TLS == nil { - scheme = "http" - } - redirectURI = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.Path) - } - - token, err := exchangeCode(provider, code, redirectURI) + redirectURI := OAuthCallbackURL() + token, err := exchangeCodeForToken(provider, code, redirectURI) if err != nil { - http.Error(w, fmt.Sprintf("token exchange failed: %v", err), http.StatusInternalServerError) + slog.Error("oauth-callback: token exchange failed", "provider", providerName, "error", err) + http.Error(w, "token exchange failed", http.StatusInternalServerError) return } - - // Write token file - tokenFile := oauthCallbackTokenDir + "/" + state + ".json" - if err := writeCallbackToken(tokenFile, token, provider); err != nil { - http.Error(w, fmt.Sprintf("failed to save token: %v", err), http.StatusInternalServerError) + tokenFile := oauthCallbackTokenDir + "/" + providerName + ".json" + if err := writeOAuthToken(tokenFile, token, provider); err != nil { + slog.Error("oauth-callback: failed to save token", "provider", providerName, "error", err) + http.Error(w, "failed to save token", http.StatusInternalServerError) return } - - // Register the new access token as a secret for log redaction gateway.RegisterSecret(token.AccessToken) - w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `

Authorization successful

Provider %s has been connected. You can close this tab.

-`, state) +`, html.EscapeString(providerName)) } -type callbackTokenResponse struct { +type oauthTokenExchangeResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int64 `json:"expires_in"` } -func exchangeCode(provider oauthProviderConfig, code, redirectURI string) (*callbackTokenResponse, error) { +func exchangeCodeForToken(provider oauthCallbackConfig, code, redirectURI string) (*oauthTokenExchangeResponse, error) { u, err := url.Parse(provider.TokenEndpoint) if err != nil { - return nil, fmt.Errorf("invalid token_endpoint URL: %w", err) + return nil, fmt.Errorf("invalid token_endpoint: %w", err) } if u.Scheme != "https" { return nil, fmt.Errorf("token_endpoint must use https, got %q", u.Scheme) } - params := url.Values{ "grant_type": {"authorization_code"}, "code": {code}, @@ -138,43 +156,34 @@ func exchangeCode(provider oauthProviderConfig, code, redirectURI string) (*call if provider.ClientSecret != "" { params.Set("client_secret", provider.ClientSecret) } - client := &http.Client{ Timeout: 30 * time.Second, Transport: callbackSSRFSafeTransport(), } - - resp, err := client.Post( - provider.TokenEndpoint, - "application/x-www-form-urlencoded", - strings.NewReader(params.Encode()), - ) + resp, err := client.Post(provider.TokenEndpoint, "application/x-www-form-urlencoded", strings.NewReader(params.Encode())) if err != nil { - return nil, fmt.Errorf("token request to %s: %w", provider.TokenEndpoint, err) + return nil, fmt.Errorf("request failed: %w", err) } defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { - return nil, fmt.Errorf("reading token response: %w", err) + return nil, fmt.Errorf("reading response: %w", err) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("token endpoint returned %d: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("token endpoint returned %d", resp.StatusCode) } - - var tr callbackTokenResponse + var tr oauthTokenExchangeResponse if err := json.Unmarshal(body, &tr); err != nil { - return nil, fmt.Errorf("parsing token response: %w", err) + return nil, fmt.Errorf("parsing response: %w", err) } return &tr, nil } -func writeCallbackToken(path string, token *callbackTokenResponse, provider oauthProviderConfig) error { +func writeOAuthToken(path string, token *oauthTokenExchangeResponse, provider oauthCallbackConfig) error { expiresIn := token.ExpiresIn if expiresIn == 0 { expiresIn = 3600 } - stored := map[string]any{ "access_token": token.AccessToken, "expires_at": time.Now().Unix() + expiresIn, @@ -184,15 +193,10 @@ func writeCallbackToken(path string, token *callbackTokenResponse, provider oaut if token.RefreshToken != "" { stored["refresh_token"] = token.RefreshToken } - if provider.ClientSecret != "" { - stored["client_secret"] = provider.ClientSecret - } - data, err := json.MarshalIndent(stored, "", " ") if err != nil { return err } - tmp := path + ".tmp" if err := os.WriteFile(tmp, data, 0600); err != nil { return err @@ -213,7 +217,7 @@ func callbackSSRFSafeTransport() *http.Transport { } for _, ip := range ips { if ip.IP.IsLoopback() || ip.IP.IsPrivate() || ip.IP.IsLinkLocalUnicast() || ip.IP.IsLinkLocalMulticast() { - return nil, fmt.Errorf("refusing to connect to private IP %s (resolved from %s)", ip.IP, host) + return nil, fmt.Errorf("refusing to connect to private IP %s", ip.IP) } } dialer := &net.Dialer{Timeout: 10 * time.Second} diff --git a/core/plugins/mcp-oauth/middlewares/oauth.go b/core/plugins/mcp-oauth/middlewares/oauth.go index 739d73b..52b6c36 100644 --- a/core/plugins/mcp-oauth/middlewares/oauth.go +++ b/core/plugins/mcp-oauth/middlewares/oauth.go @@ -11,6 +11,7 @@ import ( "net/http" "net/url" "os" + "sort" "strings" "sync" "time" @@ -24,7 +25,6 @@ type storedToken struct { ExpiresAt int64 `json:"expires_at"` TokenEndpoint string `json:"token_endpoint"` ClientID string `json:"client_id"` - ClientSecret *string `json:"client_secret"` } type oauthProviderInfo struct { @@ -49,7 +49,9 @@ func init() { providersJSON := `{{ toJSON .options.providers }}` providers := make(map[string]oauthProviderInfo) var rawProviders map[string]map[string]any - if err := json.Unmarshal([]byte(providersJSON), &rawProviders); err == nil { + if err := json.Unmarshal([]byte(providersJSON), &rawProviders); err != nil { + slog.Error("oauth: failed to parse providers config", "error", err) + } else { for name, cfg := range rawProviders { p := oauthProviderInfo{} if v, ok := cfg["authorize_endpoint"].(string); ok { @@ -65,10 +67,15 @@ func init() { } } + // Deterministic default: first provider alphabetically var defaultProvider string - for name := range providers { - defaultProvider = name - break + if len(providers) > 0 { + names := make([]string, 0, len(providers)) + for name := range providers { + names = append(names, name) + } + sort.Strings(names) + defaultProvider = names[0] } tokenFile := tokenDir + "/" + defaultProvider + ".json" @@ -117,10 +124,15 @@ func init() { } func buildAuthorizeURL(provider oauthProviderInfo, providerName string) string { + // Generate CSRF nonce (validated on callback) + state := GenerateOAuthNonce(providerName) + callbackURL := OAuthCallbackURL() + params := url.Values{ "client_id": {provider.ClientID}, "response_type": {"code"}, - "state": {providerName}, + "state": {state}, + "redirect_uri": {callbackURL}, } if provider.Scopes != "" { params.Set("scope", provider.Scopes) diff --git a/internal/generate/v1/generator.go b/internal/generate/v1/generator.go index 662e501..f4faff9 100644 --- a/internal/generate/v1/generator.go +++ b/internal/generate/v1/generator.go @@ -295,7 +295,7 @@ func (g *Generator) generateAgent(cfg *config.Config, agentDir, buildDir string) } if len(gwCfg.Routes) > 0 { allOpts := collectAllOptions(cfg) - if err := CopyRouteHandlers(g.projectDir, buildDir, gwCfg.Routes, allOpts); err != nil { + if err := CopyRouteHandlers(g.projectDir, buildDir, gwCfg.Routes, allOpts, gwCfg.PublicURL); err != nil { return nil, fmt.Errorf("copy route handlers: %w", err) } } diff --git a/internal/generate/v1/routes.go b/internal/generate/v1/routes.go index 58c10f5..632e2a4 100644 --- a/internal/generate/v1/routes.go +++ b/internal/generate/v1/routes.go @@ -67,7 +67,7 @@ func (g *Generator) resolveRouteHandler(pluginName, handler, baseDir, buildDir s // CopyRouteHandlers copies route handler .go files into the gateway build context. // Each handler is a Go template rendered with path and options, and self-registers // via init() calling gateway.RegisterRoute() — same pattern as custom middleware. -func CopyRouteHandlers(projectDir, outDir string, routes []RouteRef, opts map[string]any) error { +func CopyRouteHandlers(projectDir, outDir string, routes []RouteRef, opts map[string]any, publicURL string) error { if len(routes) == 0 { return nil } @@ -92,10 +92,11 @@ func CopyRouteHandlers(projectDir, outDir string, routes []RouteRef, opts map[st return fmt.Errorf("read route handler %s: %w", route.Handler, err) } - // Template data includes options and the namespaced path + // Template data includes options, the namespaced path, and public_url data := map[string]any{ - "options": resolved, - "path": route.Path, + "options": resolved, + "path": route.Path, + "public_url": publicURL, } // Template-render the handler file diff --git a/internal/generate/v1/routes_test.go b/internal/generate/v1/routes_test.go index 8c9096b..55f26c6 100644 --- a/internal/generate/v1/routes_test.go +++ b/internal/generate/v1/routes_test.go @@ -145,7 +145,7 @@ func init() { }, } - err := CopyRouteHandlers(tmpDir, outDir, routes, map[string]any{}) + err := CopyRouteHandlers(tmpDir, outDir, routes, map[string]any{}, "https://gateway.example.com") require.NoError(t, err) // Verify the rendered handler exists From 3cbbedaf9f6e310a0d17ea8e6fca2b60e34b7822 Mon Sep 17 00:00:00 2001 From: "dorey-agent[bot]" <3504508+dorey-agent[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 10:47:07 +0000 Subject: [PATCH 3/8] fix: resolve CI build failures in mcp-oauth plugin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove cross-package function calls (GenerateOAuthNonce, OAuthCallbackURL) between handlers/ and middlewares/ — they compile as separate packages in CI - Use deterministic HMAC-based CSRF state derived from providers config (both packages derive the same key independently) - Remove ClientSecret field from storedToken struct (no longer persisted) - Remove unused crypto/rand import --- core/plugins/mcp-oauth/handlers/callback.go | 77 ++++++++------------- core/plugins/mcp-oauth/middlewares/oauth.go | 28 +++++--- internal/generate/v1/generator.go | 4 ++ 3 files changed, 53 insertions(+), 56 deletions(-) diff --git a/core/plugins/mcp-oauth/handlers/callback.go b/core/plugins/mcp-oauth/handlers/callback.go index 532adfc..d4ae5b3 100644 --- a/core/plugins/mcp-oauth/handlers/callback.go +++ b/core/plugins/mcp-oauth/handlers/callback.go @@ -2,7 +2,8 @@ package custom import ( "context" - "crypto/rand" + "crypto/hmac" + "crypto/sha256" "encoding/hex" "encoding/json" "fmt" @@ -14,7 +15,6 @@ import ( "net/url" "os" "strings" - "sync" "time" "github.com/donbader/agent-sandbox/core/sdk/gateway" @@ -29,18 +29,20 @@ type oauthCallbackConfig struct { var ( oauthCallbackProviders = map[string]oauthCallbackConfig{} oauthCallbackTokenDir string - oauthCallbackPublicURL string - oauthNonces = map[string]string{} - oauthNoncesMu sync.Mutex + oauthCallbackHMACKey []byte ) func init() { oauthCallbackTokenDir = "{{ .options.token_dir }}" - oauthCallbackPublicURL = "{{ .public_url }}" providersJSON := `{{ toJSON .options.providers }}` + + // Derive HMAC key from providers config (same derivation as middleware) + h := sha256.Sum256([]byte(providersJSON)) + oauthCallbackHMACKey = h[:] + var providers map[string]map[string]any if err := json.Unmarshal([]byte(providersJSON), &providers); err != nil { - slog.Error("oauth-callback: failed to parse providers config", "error", err) + slog.Error("oauth-callback: failed to parse providers", "error", err) } else { for name, cfg := range providers { p := oauthCallbackConfig{} @@ -62,24 +64,6 @@ func init() { }) } -// GenerateOAuthNonce creates a CSRF nonce for the given provider. -func GenerateOAuthNonce(provider string) string { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - return provider - } - state := hex.EncodeToString(b) + ":" + provider - oauthNoncesMu.Lock() - oauthNonces[state] = provider - oauthNoncesMu.Unlock() - return state -} - -// OAuthCallbackURL returns the full callback URL. -func OAuthCallbackURL() string { - return oauthCallbackPublicURL + "{{ .path }}" -} - func handleOAuthCallback(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") state := r.URL.Query().Get("state") @@ -91,15 +75,18 @@ func handleOAuthCallback(w http.ResponseWriter, r *http.Request) { http.Error(w, "missing state parameter", http.StatusBadRequest) return } - // Validate CSRF nonce - oauthNoncesMu.Lock() - providerName, valid := oauthNonces[state] - if valid { - delete(oauthNonces, state) + // Validate HMAC state: format is "hmac_sig:provider_name" + parts := strings.SplitN(state, ":", 2) + if len(parts) != 2 { + http.Error(w, "invalid state format", http.StatusForbidden) + return } - oauthNoncesMu.Unlock() - if !valid { - http.Error(w, "invalid or expired state", http.StatusForbidden) + sig, providerName := parts[0], parts[1] + mac := hmac.New(sha256.New, oauthCallbackHMACKey) + mac.Write([]byte(providerName)) + expectedSig := hex.EncodeToString(mac.Sum(nil))[:16] + if !hmac.Equal([]byte(sig), []byte(expectedSig)) { + http.Error(w, "invalid state signature", http.StatusForbidden) return } provider, ok := oauthCallbackProviders[providerName] @@ -111,25 +98,24 @@ func handleOAuthCallback(w http.ResponseWriter, r *http.Request) { http.Error(w, "provider not configured", http.StatusInternalServerError) return } - redirectURI := OAuthCallbackURL() + redirectURI := "{{ .public_url }}{{ .path }}" token, err := exchangeCodeForToken(provider, code, redirectURI) if err != nil { - slog.Error("oauth-callback: token exchange failed", "provider", providerName, "error", err) + slog.Error("oauth-callback: token exchange failed", "error", err) http.Error(w, "token exchange failed", http.StatusInternalServerError) return } tokenFile := oauthCallbackTokenDir + "/" + providerName + ".json" if err := writeOAuthToken(tokenFile, token, provider); err != nil { - slog.Error("oauth-callback: failed to save token", "provider", providerName, "error", err) + slog.Error("oauth-callback: failed to save token", "error", err) http.Error(w, "failed to save token", http.StatusInternalServerError) return } gateway.RegisterSecret(token.AccessToken) w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `

Authorization successful

-

Provider %s has been connected. You can close this tab.

+

Provider %s connected. You can close this tab.

`, html.EscapeString(providerName)) } @@ -156,10 +142,7 @@ func exchangeCodeForToken(provider oauthCallbackConfig, code, redirectURI string if provider.ClientSecret != "" { params.Set("client_secret", provider.ClientSecret) } - client := &http.Client{ - Timeout: 30 * time.Second, - Transport: callbackSSRFSafeTransport(), - } + client := &http.Client{Timeout: 30 * time.Second, Transport: cbSSRFSafe()} resp, err := client.Post(provider.TokenEndpoint, "application/x-www-form-urlencoded", strings.NewReader(params.Encode())) if err != nil { return nil, fmt.Errorf("request failed: %w", err) @@ -179,16 +162,14 @@ func exchangeCodeForToken(provider oauthCallbackConfig, code, redirectURI string return &tr, nil } -func writeOAuthToken(path string, token *oauthTokenExchangeResponse, provider oauthCallbackConfig) error { +func writeOAuthToken(path string, token *oauthTokenExchangeResponse, _ oauthCallbackConfig) error { expiresIn := token.ExpiresIn if expiresIn == 0 { expiresIn = 3600 } stored := map[string]any{ - "access_token": token.AccessToken, - "expires_at": time.Now().Unix() + expiresIn, - "token_endpoint": provider.TokenEndpoint, - "client_id": provider.ClientID, + "access_token": token.AccessToken, + "expires_at": time.Now().Unix() + expiresIn, } if token.RefreshToken != "" { stored["refresh_token"] = token.RefreshToken @@ -204,7 +185,7 @@ func writeOAuthToken(path string, token *oauthTokenExchangeResponse, provider oa return os.Rename(tmp, path) } -func callbackSSRFSafeTransport() *http.Transport { +func cbSSRFSafe() *http.Transport { return &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) diff --git a/core/plugins/mcp-oauth/middlewares/oauth.go b/core/plugins/mcp-oauth/middlewares/oauth.go index 52b6c36..6487c09 100644 --- a/core/plugins/mcp-oauth/middlewares/oauth.go +++ b/core/plugins/mcp-oauth/middlewares/oauth.go @@ -2,6 +2,9 @@ package custom import ( "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -42,11 +45,18 @@ type oauthState struct { httpClient *http.Client } +// oauthHMACKey is derived deterministically from providers config for CSRF state signing. +var oauthHMACKey []byte + func init() { + // Derive HMAC key from providers config (same derivation as callback handler) + providersJSON := `{{ toJSON .options.providers }}` + h := sha256.Sum256([]byte(providersJSON)) + oauthHMACKey = h[:] + tokenDir := "{{ .options.token_dir }}" domains := strings.Split("{{ .domainsList }}", ",") - providersJSON := `{{ toJSON .options.providers }}` providers := make(map[string]oauthProviderInfo) var rawProviders map[string]map[string]any if err := json.Unmarshal([]byte(providersJSON), &rawProviders); err != nil { @@ -124,9 +134,13 @@ func init() { } func buildAuthorizeURL(provider oauthProviderInfo, providerName string) string { - // Generate CSRF nonce (validated on callback) - state := GenerateOAuthNonce(providerName) - callbackURL := OAuthCallbackURL() + // Generate CSRF state: HMAC(provider) ensures only this gateway can validate + mac := hmac.New(sha256.New, oauthHMACKey) + mac.Write([]byte(providerName)) + sig := hex.EncodeToString(mac.Sum(nil))[:16] + state := sig + ":" + providerName + + callbackURL := "{{ .options.callback_url }}" params := url.Values{ "client_id": {provider.ClientID}, @@ -232,9 +246,8 @@ func (s *oauthState) refreshToken(stored *storedToken) (*storedToken, error) { "refresh_token": {*stored.RefreshToken}, "client_id": {stored.ClientID}, } - if stored.ClientSecret != nil && *stored.ClientSecret != "" { - params.Set("client_secret", *stored.ClientSecret) - } + // client_secret is baked into the providers config at generate time + // (not stored in token file for security) resp, err := s.httpClient.Post( stored.TokenEndpoint, "application/x-www-form-urlencoded", @@ -269,7 +282,6 @@ func (s *oauthState) refreshToken(stored *storedToken) (*storedToken, error) { ExpiresAt: time.Now().Unix() + expiresIn, TokenEndpoint: stored.TokenEndpoint, ClientID: stored.ClientID, - ClientSecret: stored.ClientSecret, }, nil } diff --git a/internal/generate/v1/generator.go b/internal/generate/v1/generator.go index f4faff9..8d3d40c 100644 --- a/internal/generate/v1/generator.go +++ b/internal/generate/v1/generator.go @@ -284,6 +284,10 @@ func (g *Generator) generateAgent(cfg *config.Config, agentDir, buildDir string) } if len(gwCfg.Middlewares) > 0 { allOpts := collectAllOptions(cfg) + // Inject computed callback URLs for plugins with routes + for _, route := range gwCfg.Routes { + allOpts["callback_url"] = gwCfg.PublicURL + route.Path + } if err := CopyCustomMiddleware(g.projectDir, buildDir, gwCfg.Middlewares, allOpts); err != nil { return nil, fmt.Errorf("copy middleware: %w", err) } From 3bbdae594b815566beba507d731dfc7f7b4fd651 Mon Sep 17 00:00:00 2001 From: "dorey-agent[bot]" <3504508+dorey-agent[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 10:51:04 +0000 Subject: [PATCH 4/8] fix: handle errcheck lint for fmt.Fprintf in callback handler --- core/plugins/mcp-oauth/handlers/callback.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/plugins/mcp-oauth/handlers/callback.go b/core/plugins/mcp-oauth/handlers/callback.go index d4ae5b3..bd97861 100644 --- a/core/plugins/mcp-oauth/handlers/callback.go +++ b/core/plugins/mcp-oauth/handlers/callback.go @@ -113,7 +113,7 @@ func handleOAuthCallback(w http.ResponseWriter, r *http.Request) { } gateway.RegisterSecret(token.AccessToken) w.Header().Set("Content-Type", "text/html; charset=utf-8") - fmt.Fprintf(w, ` + _, _ = fmt.Fprintf(w, `

Authorization successful

Provider %s connected. You can close this tab.

`, html.EscapeString(providerName)) From 16fb635643765940671b3882f721e7ce337ff7aa Mon Sep 17 00:00:00 2001 From: "dorey-agent[bot]" <3504508+dorey-agent[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 14:43:32 +0000 Subject: [PATCH 5/8] feat: add dynamic OAuth client registration (RFC 7591) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Gateway SDK: DiscoverOAuthMetadata() + RegisterOAuthClient() utilities - Middleware auto-detects mode: client_id present → static, absent → dynamic - Dynamic mode: discovers .well-known/oauth-authorization-server, registers client - Registration response cached to {token_dir}/{provider}.reg.json for reuse - Updated README with dual-mode examples (dynamic for Notion, static for custom) - For Notion/MCP providers: just provide mcp_url, everything else auto-discovered --- core/plugins/mcp-oauth/README.md | 21 +++- core/plugins/mcp-oauth/middlewares/oauth.go | 107 +++++++++++++--- core/sdk/gateway/oauth_discovery.go | 133 ++++++++++++++++++++ 3 files changed, 243 insertions(+), 18 deletions(-) create mode 100644 core/sdk/gateway/oauth_discovery.go diff --git a/core/plugins/mcp-oauth/README.md b/core/plugins/mcp-oauth/README.md index ae7b3aa..962c958 100644 --- a/core/plugins/mcp-oauth/README.md +++ b/core/plugins/mcp-oauth/README.md @@ -21,10 +21,15 @@ installations: - plugin: "@builtin/mcp-oauth" options: providers: + # Dynamic mode: just provide mcp_url — credentials auto-discovered notion: mcp_url: https://mcp.notion.com/mcp - authorize_endpoint: https://api.notion.com/v1/oauth/authorize - token_endpoint: https://api.notion.com/v1/oauth/token + + # Static mode: provide all OAuth details manually + custom-provider: + mcp_url: https://custom.example.com/mcp + authorize_endpoint: https://custom.example.com/oauth/authorize + token_endpoint: https://custom.example.com/oauth/token client_id: "your-client-id" client_secret: "your-client-secret" scopes: "read_content" @@ -40,7 +45,15 @@ installations: ### Provider Config -Each provider entry supports: +Each provider entry supports two modes: + +**Dynamic mode** (recommended for MCP servers that support RFC 7591): + +| Field | Required | Description | +|-------|----------|-------------| +| `mcp_url` | yes | MCP server endpoint — metadata + registration auto-discovered | + +**Static mode** (for providers without dynamic registration): | Field | Required | Description | |-------|----------|-------------| @@ -51,6 +64,8 @@ Each provider entry supports: | `client_secret` | no | OAuth client secret | | `scopes` | no | Space-separated scopes | +Mode is auto-detected: if `client_id` is absent, dynamic mode is used. + ## What It Contributes - **Gateway middleware:** Token injection + 401 with authorize URL when unauthenticated diff --git a/core/plugins/mcp-oauth/middlewares/oauth.go b/core/plugins/mcp-oauth/middlewares/oauth.go index 6487c09..ee8f812 100644 --- a/core/plugins/mcp-oauth/middlewares/oauth.go +++ b/core/plugins/mcp-oauth/middlewares/oauth.go @@ -32,8 +32,11 @@ type storedToken struct { type oauthProviderInfo struct { AuthorizeEndpoint string + TokenEndpoint string ClientID string + ClientSecret string Scopes string + Dynamic bool } type oauthState struct { @@ -45,16 +48,15 @@ type oauthState struct { httpClient *http.Client } -// oauthHMACKey is derived deterministically from providers config for CSRF state signing. var oauthHMACKey []byte func init() { - // Derive HMAC key from providers config (same derivation as callback handler) providersJSON := `{{ toJSON .options.providers }}` h := sha256.Sum256([]byte(providersJSON)) oauthHMACKey = h[:] tokenDir := "{{ .options.token_dir }}" + callbackURL := "{{ .options.callback_url }}" domains := strings.Split("{{ .domainsList }}", ",") providers := make(map[string]oauthProviderInfo) @@ -67,17 +69,34 @@ func init() { if v, ok := cfg["authorize_endpoint"].(string); ok { p.AuthorizeEndpoint = v } + if v, ok := cfg["token_endpoint"].(string); ok { + p.TokenEndpoint = v + } if v, ok := cfg["client_id"].(string); ok { p.ClientID = v } + if v, ok := cfg["client_secret"].(string); ok { + p.ClientSecret = v + } if v, ok := cfg["scopes"].(string); ok { p.Scopes = v } + // Dynamic mode: no client_id means we need discovery+registration + if p.ClientID == "" { + p.Dynamic = true + mcpURL, _ := cfg["mcp_url"].(string) + resolved := resolveProvider(mcpURL, callbackURL, tokenDir, name) + if resolved != nil { + p.AuthorizeEndpoint = resolved.AuthorizeEndpoint + p.TokenEndpoint = resolved.TokenEndpoint + p.ClientID = resolved.ClientID + p.ClientSecret = resolved.ClientSecret + } + } providers[name] = p } } - // Deterministic default: first provider alphabetically var defaultProvider string if len(providers) > 0 { names := make([]string, 0, len(providers)) @@ -99,7 +118,7 @@ func init() { } if _, err := os.Stat(tokenFile); err != nil { - slog.Warn("oauth token file not found at startup", "path", tokenFile, "error", err) + slog.Warn("oauth: token file not found at startup", "path", tokenFile) } for _, s := range oauthSecrets(tokenFile) { gateway.RegisterSecret(s) @@ -112,8 +131,8 @@ func init() { token, err := state.getValidToken(tokenFile) if err != nil { if errors.Is(err, os.ErrNotExist) { - if provider, ok := state.providers[defaultProvider]; ok && provider.AuthorizeEndpoint != "" { - authorizeURL := buildAuthorizeURL(provider, defaultProvider) + if p, ok := state.providers[defaultProvider]; ok && p.AuthorizeEndpoint != "" { + authorizeURL := mwBuildAuthorizeURL(p, defaultProvider, callbackURL) ctx.SetAbortHeader("X-OAuth-Authorize-URL", authorizeURL) ctx.SetAbortHeader("Content-Type", "application/json") ctx.Abort(http.StatusUnauthorized, fmt.Sprintf( @@ -133,15 +152,75 @@ func init() { }) } -func buildAuthorizeURL(provider oauthProviderInfo, providerName string) string { - // Generate CSRF state: HMAC(provider) ensures only this gateway can validate +// resolveProvider does OAuth metadata discovery + dynamic client registration. +// Returns nil if discovery fails (provider will be skipped). +func resolveProvider(mcpURL, callbackURL, tokenDir, name string) *oauthProviderInfo { + if mcpURL == "" { + slog.Error("oauth: dynamic provider has no mcp_url", "provider", name) + return nil + } + + // Check for cached registration + regFile := tokenDir + "/" + name + ".reg.json" + if data, err := os.ReadFile(regFile); err == nil { + var cached struct { + AuthorizeEndpoint string `json:"authorize_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + } + if json.Unmarshal(data, &cached) == nil && cached.ClientID != "" { + slog.Info("oauth: using cached dynamic registration", "provider", name) + return &oauthProviderInfo{ + AuthorizeEndpoint: cached.AuthorizeEndpoint, + TokenEndpoint: cached.TokenEndpoint, + ClientID: cached.ClientID, + ClientSecret: cached.ClientSecret, + } + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + meta, err := gateway.DiscoverOAuthMetadata(ctx, mcpURL) + if err != nil { + slog.Error("oauth: metadata discovery failed", "provider", name, "error", err) + return nil + } + + reg, err := gateway.RegisterOAuthClient(ctx, meta.RegistrationEndpoint, []string{callbackURL}, "agent-sandbox:"+name) + if err != nil { + slog.Error("oauth: dynamic registration failed", "provider", name, "error", err) + return nil + } + + // Persist registration for reuse across restarts + regData, _ := json.MarshalIndent(map[string]string{ + "authorize_endpoint": meta.AuthorizationEndpoint, + "token_endpoint": meta.TokenEndpoint, + "client_id": reg.ClientID, + "client_secret": reg.ClientSecret, + }, "", " ") + if err := os.WriteFile(regFile, regData, 0600); err != nil { + slog.Warn("oauth: failed to cache registration", "provider", name, "error", err) + } + + slog.Info("oauth: dynamic registration complete", "provider", name, "client_id", reg.ClientID) + return &oauthProviderInfo{ + AuthorizeEndpoint: meta.AuthorizationEndpoint, + TokenEndpoint: meta.TokenEndpoint, + ClientID: reg.ClientID, + ClientSecret: reg.ClientSecret, + } +} + +func mwBuildAuthorizeURL(provider oauthProviderInfo, providerName, callbackURL string) string { mac := hmac.New(sha256.New, oauthHMACKey) mac.Write([]byte(providerName)) sig := hex.EncodeToString(mac.Sum(nil))[:16] state := sig + ":" + providerName - callbackURL := "{{ .options.callback_url }}" - params := url.Values{ "client_id": {provider.ClientID}, "response_type": {"code"}, @@ -246,15 +325,13 @@ func (s *oauthState) refreshToken(stored *storedToken) (*storedToken, error) { "refresh_token": {*stored.RefreshToken}, "client_id": {stored.ClientID}, } - // client_secret is baked into the providers config at generate time - // (not stored in token file for security) resp, err := s.httpClient.Post( stored.TokenEndpoint, "application/x-www-form-urlencoded", strings.NewReader(params.Encode()), ) if err != nil { - return nil, fmt.Errorf("refresh request to %s: %w", stored.TokenEndpoint, err) + return nil, fmt.Errorf("refresh request: %w", err) } defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) @@ -262,7 +339,7 @@ func (s *oauthState) refreshToken(stored *storedToken) (*storedToken, error) { return nil, fmt.Errorf("reading refresh response: %w", err) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("token refresh returned %d: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("token refresh returned %d", resp.StatusCode) } var tr oauthTokenResponse if err := json.Unmarshal(body, &tr); err != nil { @@ -298,7 +375,7 @@ func oauthSSRFSafeTransport() *http.Transport { } for _, ip := range ips { if ip.IP.IsLoopback() || ip.IP.IsPrivate() || ip.IP.IsLinkLocalUnicast() || ip.IP.IsLinkLocalMulticast() { - return nil, fmt.Errorf("oauth: refusing to connect to private IP %s (resolved from %s)", ip.IP, host) + return nil, fmt.Errorf("oauth: refusing to connect to private IP %s", ip.IP) } } dialer := &net.Dialer{Timeout: 10 * time.Second} diff --git a/core/sdk/gateway/oauth_discovery.go b/core/sdk/gateway/oauth_discovery.go new file mode 100644 index 0000000..91b826a --- /dev/null +++ b/core/sdk/gateway/oauth_discovery.go @@ -0,0 +1,133 @@ +package gateway + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// OAuthMetadata holds discovered OAuth server metadata (RFC 8414). +type OAuthMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint"` + ScopesSupported []string `json:"scopes_supported"` +} + +// OAuthRegistration holds dynamic client registration response (RFC 7591). +type OAuthRegistration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` + RedirectURIs []string `json:"redirect_uris,omitempty"` +} + +// DiscoverOAuthMetadata fetches OAuth server metadata from the MCP server. +// Tries .well-known/oauth-authorization-server first, falls back to .well-known/openid-configuration. +func DiscoverOAuthMetadata(ctx context.Context, mcpURL string) (*OAuthMetadata, error) { + base, err := url.Parse(mcpURL) + if err != nil { + return nil, fmt.Errorf("parse mcp_url: %w", err) + } + + // Try RFC 8414 path first + wellKnownPaths := []string{ + "/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", + } + + client := &http.Client{Timeout: 15 * time.Second} + + for _, path := range wellKnownPaths { + metaURL := base.Scheme + "://" + base.Host + path + req, err := http.NewRequestWithContext(ctx, "GET", metaURL, nil) + if err != nil { + continue + } + resp, err := client.Do(req) + if err != nil { + continue + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + continue + } + + var meta OAuthMetadata + if err := json.Unmarshal(body, &meta); err != nil { + continue + } + if meta.AuthorizationEndpoint != "" && meta.TokenEndpoint != "" { + return &meta, nil + } + } + + return nil, fmt.Errorf("oauth metadata not found at %s", base.Host) +} + +// RegisterOAuthClient performs Dynamic Client Registration (RFC 7591). +func RegisterOAuthClient(ctx context.Context, registrationEndpoint string, redirectURIs []string, clientName string) (*OAuthRegistration, error) { + if registrationEndpoint == "" { + return nil, fmt.Errorf("no registration_endpoint available") + } + + u, err := url.Parse(registrationEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid registration_endpoint: %w", err) + } + if u.Scheme != "https" { + return nil, fmt.Errorf("registration_endpoint must use https, got %q", u.Scheme) + } + + reqBody := map[string]any{ + "client_name": clientName, + "redirect_uris": redirectURIs, + "grant_types": []string{"authorization_code", "refresh_token"}, + "response_types": []string{"code"}, + "token_endpoint_auth_method": "client_secret_post", + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal registration request: %w", err) + } + + client := &http.Client{Timeout: 15 * time.Second} + req, err := http.NewRequestWithContext(ctx, "POST", registrationEndpoint, strings.NewReader(string(bodyBytes))) + if err != nil { + return nil, fmt.Errorf("create registration request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("registration request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("reading registration response: %w", err) + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("registration returned %d: %s", resp.StatusCode, string(body)) + } + + var reg OAuthRegistration + if err := json.Unmarshal(body, ®); err != nil { + return nil, fmt.Errorf("parse registration response: %w", err) + } + if reg.ClientID == "" { + return nil, fmt.Errorf("registration response missing client_id") + } + + return ®, nil +} From acd5ac03fc41f4bd310636593bc6d1c327cae461 Mon Sep 17 00:00:00 2001 From: "dorey-agent[bot]" <3504508+dorey-agent[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 14:54:19 +0000 Subject: [PATCH 6/8] feat: localhost fallback for OAuth callbacks + Notion MCP example - Gateway main.go: route handler dispatch on health server (port 8080) - Compose: auto-expose gateway port 8080 when plugin routes are registered - Default public_url to http://localhost:8080 when not configured - local-coding example: add Notion MCP with mcp-oauth plugin (dynamic mode) - README: document Notion OAuth setup flow (zero-config, just mcp_url) --- core/gateway/cmd/gateway/main.go | 11 +++++++++- examples/local-coding/README.md | 28 ++++++++++++++++++++++++++ examples/local-coding/agent.yaml | 6 ++++++ internal/generate/v1/compose.go | 4 ++++ internal/generate/v1/gateway_config.go | 8 +++++++- 5 files changed, 55 insertions(+), 2 deletions(-) diff --git a/core/gateway/cmd/gateway/main.go b/core/gateway/cmd/gateway/main.go index 6df3a9a..4c85f37 100644 --- a/core/gateway/cmd/gateway/main.go +++ b/core/gateway/cmd/gateway/main.go @@ -120,7 +120,7 @@ func main() { }() } - // Health endpoint + // Health + route handler endpoint healthAddr := ":8080" go func() { mux := http.NewServeMux() @@ -128,6 +128,15 @@ func main() { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) }) + // Serve plugin-registered routes (e.g. /plugins/mcp-oauth/callback) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + handler := gateway.MatchRoute(r.URL.Path) + if handler != nil { + handler(w, r) + return + } + http.NotFound(w, r) + }) if err := http.ListenAndServe(healthAddr, mux); err != nil { slog.Error("health server error", "error", err) } diff --git a/examples/local-coding/README.md b/examples/local-coding/README.md index 2d90e7a..a5fd204 100644 --- a/examples/local-coding/README.md +++ b/examples/local-coding/README.md @@ -89,3 +89,31 @@ With `volume: true`, the home directory persists across container restarts (shel | Variable | Description | |----------|-------------| | `STX_LLM_GATEWAY_API_KEY` | API key for the LLM gateway | + +## Notion MCP Integration + +This example includes the `mcp-oauth` plugin configured for Notion. On first use: + +1. Start the sandbox: + ```bash + agent-sandbox generate + agent-sandbox compose up --build -d + ``` + +2. When the agent tries to access Notion, the gateway returns a 401 with an authorize URL. The URL is also available at: + ``` + http://localhost:8080/plugins/mcp-oauth/callback + ``` + +3. The first request to Notion's MCP will trigger OAuth discovery and dynamic client registration automatically. The gateway logs the authorize URL: + ```bash + agent-sandbox compose logs coder-gateway | grep authorize_url + ``` + +4. Open the authorize URL in your browser, log in to Notion, and authorize access. + +5. Notion redirects to `http://localhost:8080/plugins/mcp-oauth/callback` — the gateway exchanges the code for tokens and stores them automatically. + +6. Subsequent requests to Notion are authenticated transparently. + +> **Note:** No Notion developer app or API key needed. The plugin uses OAuth Dynamic Client Registration (RFC 7591) — credentials are obtained automatically. diff --git a/examples/local-coding/agent.yaml b/examples/local-coding/agent.yaml index 5ea20c9..aec2924 100644 --- a/examples/local-coding/agent.yaml +++ b/examples/local-coding/agent.yaml @@ -10,8 +10,14 @@ gateway: - url: https://agent-gateway.stx-ai.net headers: Authorization: Bearer ${STX_LLM_GATEWAY_API_KEY} + - url: https://mcp.notion.com installations: - plugin: "@builtin/home-override" options: home_directory: "./home" volume: true + - plugin: "@builtin/mcp-oauth" + options: + providers: + notion: + mcp_url: https://mcp.notion.com/mcp diff --git a/internal/generate/v1/compose.go b/internal/generate/v1/compose.go index 30132d6..329d73a 100644 --- a/internal/generate/v1/compose.go +++ b/internal/generate/v1/compose.go @@ -110,6 +110,10 @@ func BuildCompose(cfg *config.Config, contribs *plugin.Contributions, projectDir "retries": 3, }, } + // Expose gateway HTTP port when plugin routes are registered (e.g. OAuth callbacks) + if contribs != nil && len(contribs.Gateway.Routes) > 0 { + gatewaySvc["ports"] = []string{"8080:8080"} + } if len(gatewayEnv) > 0 { gatewaySvc["environment"] = gatewayEnv } diff --git a/internal/generate/v1/gateway_config.go b/internal/generate/v1/gateway_config.go index 9789a1c..d296413 100644 --- a/internal/generate/v1/gateway_config.go +++ b/internal/generate/v1/gateway_config.go @@ -62,8 +62,14 @@ type gatewayRuntimeConfig struct { // BuildGatewayConfig merges user gateway config with plugin contributions. func BuildGatewayConfig(cfg *config.Config, contribs *plugin.Contributions) *GatewayConfigOutput { + publicURL := cfg.Gateway.PublicURL + // Default to localhost:8080 when no public_url configured (local dev) + if publicURL == "" { + publicURL = "http://localhost:8080" + } + out := &GatewayConfigOutput{ - PublicURL: cfg.Gateway.PublicURL, + PublicURL: publicURL, } // User-declared services From 96c2e1aa565dd9823fd85f1eae7041869aab4446 Mon Sep 17 00:00:00 2001 From: "dorey-agent[bot]" <3504508+dorey-agent[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 15:00:53 +0000 Subject: [PATCH 7/8] feat: auto-contribute gateway services from mcp-oauth providers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Plugin now generates gateway service entries from providers.mcp_url automatically — users don't need to duplicate the URL in gateway.services. Uses Go template range over providers map to emit service entries at generate time. --- core/plugins/mcp-oauth/plugin.yaml | 6 +++++- examples/local-coding/agent.yaml | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/core/plugins/mcp-oauth/plugin.yaml b/core/plugins/mcp-oauth/plugin.yaml index 224f985..2c2d21e 100644 --- a/core/plugins/mcp-oauth/plugin.yaml +++ b/core/plugins/mcp-oauth/plugin.yaml @@ -3,7 +3,7 @@ options: providers: type: object required: true - description: "Map of provider name to MCP config (each needs mcp_url, client_id, client_secret)" + description: "Map of provider name to MCP config (each needs at least mcp_url)" token_dir: type: string required: false @@ -12,6 +12,10 @@ options: contributes: gateway: + services: +{{- range $name, $cfg := .plugin.options.providers }} + - url: "{{ index $cfg "mcp_url" }}" +{{- end }} volumes: - "oauth-tokens:{{ .plugin.options.token_dir }}" middlewares: diff --git a/examples/local-coding/agent.yaml b/examples/local-coding/agent.yaml index aec2924..9ce4d74 100644 --- a/examples/local-coding/agent.yaml +++ b/examples/local-coding/agent.yaml @@ -10,7 +10,6 @@ gateway: - url: https://agent-gateway.stx-ai.net headers: Authorization: Bearer ${STX_LLM_GATEWAY_API_KEY} - - url: https://mcp.notion.com installations: - plugin: "@builtin/home-override" options: From 699dcddc819a9d8e8c40e38e13f5731f9946b893 Mon Sep 17 00:00:00 2001 From: "dorey-agent[bot]" <3504508+dorey-agent[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 15:06:02 +0000 Subject: [PATCH 8/8] fix: register per-provider middleware with domain scoping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each provider gets its own middleware scoped to its mcp_url domain. Request to mcp.notion.com → notion's authorize URL. Request to mcp.datadog.com → datadog's authorize URL. Previously used a single middleware with a hardcoded default provider, which returned the wrong authorize URL for multi-provider configs. --- core/plugins/mcp-oauth/middlewares/oauth.go | 89 +++++++++++---------- 1 file changed, 47 insertions(+), 42 deletions(-) diff --git a/core/plugins/mcp-oauth/middlewares/oauth.go b/core/plugins/mcp-oauth/middlewares/oauth.go index ee8f812..505581e 100644 --- a/core/plugins/mcp-oauth/middlewares/oauth.go +++ b/core/plugins/mcp-oauth/middlewares/oauth.go @@ -14,7 +14,6 @@ import ( "net/http" "net/url" "os" - "sort" "strings" "sync" "time" @@ -57,7 +56,6 @@ func init() { tokenDir := "{{ .options.token_dir }}" callbackURL := "{{ .options.callback_url }}" - domains := strings.Split("{{ .domainsList }}", ",") providers := make(map[string]oauthProviderInfo) var rawProviders map[string]map[string]any @@ -97,17 +95,7 @@ func init() { } } - var defaultProvider string - if len(providers) > 0 { - names := make([]string, 0, len(providers)) - for name := range providers { - names = append(names, name) - } - sort.Strings(names) - defaultProvider = names[0] - } - - tokenFile := tokenDir + "/" + defaultProvider + ".json" + // Register one middleware per provider, scoped to that provider's domain state := &oauthState{ tokenDir: tokenDir, providers: providers, @@ -117,39 +105,56 @@ func init() { }, } - if _, err := os.Stat(tokenFile); err != nil { - slog.Warn("oauth: token file not found at startup", "path", tokenFile) - } - for _, s := range oauthSecrets(tokenFile) { - gateway.RegisterSecret(s) - } + for name, p := range providers { + providerName := name + provider := p + providerTokenFile := tokenDir + "/" + providerName + ".json" - gateway.RegisterMiddleware(gateway.MiddlewareDef{ - Name: "oauth:" + domains[0], - Domains: domains, - Func: func(ctx *gateway.MiddlewareContext) error { - token, err := state.getValidToken(tokenFile) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - if p, ok := state.providers[defaultProvider]; ok && p.AuthorizeEndpoint != "" { - authorizeURL := mwBuildAuthorizeURL(p, defaultProvider, callbackURL) - ctx.SetAbortHeader("X-OAuth-Authorize-URL", authorizeURL) - ctx.SetAbortHeader("Content-Type", "application/json") - ctx.Abort(http.StatusUnauthorized, fmt.Sprintf( - `{"error":"oauth_required","provider":%q,"authorize_url":%q}`, - defaultProvider, authorizeURL)) - return nil + // Extract domain from mcp_url for this provider + var providerDomain string + if raw, ok := rawProviders[providerName]; ok { + if mcpURL, ok := raw["mcp_url"].(string); ok { + if u, err := url.Parse(mcpURL); err == nil { + providerDomain = u.Hostname() + } + } + } + if providerDomain == "" { + continue + } + + // Register secrets for this provider's token + for _, s := range oauthSecrets(providerTokenFile) { + gateway.RegisterSecret(s) + } + + gateway.RegisterMiddleware(gateway.MiddlewareDef{ + Name: "oauth:" + providerName, + Domains: []string{providerDomain}, + Func: func(ctx *gateway.MiddlewareContext) error { + token, err := state.getValidToken(providerTokenFile) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + if provider.AuthorizeEndpoint != "" { + authorizeURL := mwBuildAuthorizeURL(provider, providerName, callbackURL) + ctx.SetAbortHeader("X-OAuth-Authorize-URL", authorizeURL) + ctx.SetAbortHeader("Content-Type", "application/json") + ctx.Abort(http.StatusUnauthorized, fmt.Sprintf( + `{"error":"oauth_required","provider":%q,"authorize_url":%q}`, + providerName, authorizeURL)) + return nil + } + slog.Debug("oauth: token file not found", "file", providerTokenFile) + } else { + slog.Error("oauth: failed to get token", "provider", providerName, "error", err) } - slog.Debug("oauth: token file not found", "file", tokenFile) - } else { - slog.Error("oauth: failed to get token", "error", err) + return nil } + ctx.Request.Header.Set("Authorization", "Bearer "+token) return nil - } - ctx.Request.Header.Set("Authorization", "Bearer "+token) - return nil - }, - }) + }, + }) + } } // resolveProvider does OAuth metadata discovery + dynamic client registration.