diff --git a/core/gateway/cmd/gateway/main.go b/core/gateway/cmd/gateway/main.go index 6df3a9a..4c85f37 100644 --- a/core/gateway/cmd/gateway/main.go +++ b/core/gateway/cmd/gateway/main.go @@ -120,7 +120,7 @@ func main() { }() } - // Health endpoint + // Health + route handler endpoint healthAddr := ":8080" go func() { mux := http.NewServeMux() @@ -128,6 +128,15 @@ func main() { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) }) + // Serve plugin-registered routes (e.g. /plugins/mcp-oauth/callback) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + handler := gateway.MatchRoute(r.URL.Path) + if handler != nil { + handler(w, r) + return + } + http.NotFound(w, r) + }) if err := http.ListenAndServe(healthAddr, mux); err != nil { slog.Error("health server error", "error", err) } diff --git a/core/plugins/mcp-oauth/README.md b/core/plugins/mcp-oauth/README.md index a12e07b..962c958 100644 --- a/core/plugins/mcp-oauth/README.md +++ b/core/plugins/mcp-oauth/README.md @@ -1,45 +1,84 @@ # mcp-oauth -Provides OAuth token storage for MCP (Model Context Protocol) providers via a shared volume between the gateway and agent. +Provides full OAuth lifecycle for MCP (Model Context Protocol) providers: automatic token injection, refresh, and browser-based authorization via gateway callback. ## How It Works -Declares a named volume (`oauth-tokens`) mounted into both the gateway and agent containers at a configurable path. MCP providers that require OAuth can store and read token files from this shared location. - -> **Note:** This plugin currently only handles the volume declaration. You must manually declare gateway services for each MCP provider endpoint in your `agent.yaml`. Full automation (dynamic service entries from provider URLs) is planned. +1. **Middleware** intercepts requests to configured domains. If a valid token exists, injects `Authorization: Bearer `. If no token exists, returns 401 with an `authorize_url` for the user to click. +2. **Callback handler** at `/plugins/mcp-oauth/callback` receives the OAuth authorization code, exchanges it for tokens, and writes the token file to the shared volume. +3. **Shared volume** (`oauth-tokens`) is mounted into both gateway and agent containers so the MCP client can read tokens written by the gateway. ## Usage ```yaml # agent.yaml +gateway: + public_url: "https://gateway.myagent.example.com" + services: + - url: https://mcp.notion.com + installations: - plugin: "@builtin/mcp-oauth" options: providers: + # Dynamic mode: just provide mcp_url — credentials auto-discovered notion: mcp_url: https://mcp.notion.com/mcp - token_dir: "/data/oauth-tokens" -gateway: - services: - - url: https://mcp.notion.com - headers: - Authorization: Bearer ${NOTION_TOKEN} -``` - -```bash -# .env -NOTION_TOKEN=ntn_xxxx + # Static mode: provide all OAuth details manually + custom-provider: + mcp_url: https://custom.example.com/mcp + authorize_endpoint: https://custom.example.com/oauth/authorize + token_endpoint: https://custom.example.com/oauth/token + client_id: "your-client-id" + client_secret: "your-client-secret" + scopes: "read_content" + token_dir: "/data/oauth-tokens" ``` ## Options | Option | Type | Required | Default | Description | |--------|------|----------|---------|-------------| -| `providers` | object | yes | — | Map of provider name to config. Each provider needs at least `mcp_url`. | +| `providers` | object | yes | — | Map of provider name to OAuth config | | `token_dir` | string | no | `/data/oauth-tokens` | Directory for OAuth token files | +### Provider Config + +Each provider entry supports two modes: + +**Dynamic mode** (recommended for MCP servers that support RFC 7591): + +| Field | Required | Description | +|-------|----------|-------------| +| `mcp_url` | yes | MCP server endpoint — metadata + registration auto-discovered | + +**Static mode** (for providers without dynamic registration): + +| Field | Required | Description | +|-------|----------|-------------| +| `mcp_url` | yes | MCP server endpoint | +| `authorize_endpoint` | yes | OAuth authorize URL | +| `token_endpoint` | yes | OAuth token exchange URL | +| `client_id` | yes | OAuth client ID | +| `client_secret` | no | OAuth client secret | +| `scopes` | no | Space-separated scopes | + +Mode is auto-detected: if `client_id` is absent, dynamic mode is used. + ## What It Contributes -- **Gateway:** Shared volume `oauth-tokens` mounted at `token_dir` -- **Agent:** Same volume accessible for MCP client token reads +- **Gateway middleware:** Token injection + 401 with authorize URL when unauthenticated +- **Gateway route:** `/plugins/mcp-oauth/callback` — OAuth code exchange handler +- **Gateway volume:** Shared `oauth-tokens` volume at `token_dir` + +## OAuth Flow + +``` +1. Agent MCP client → request to notion domain +2. Gateway middleware: no token file → returns 401 + authorize_url +3. User clicks authorize_url → Notion login page +4. Notion redirects → https://gateway.example.com/plugins/mcp-oauth/callback?code=X&state=notion +5. Gateway callback handler: exchanges code → writes /data/oauth-tokens/notion.json +6. Next request → middleware reads token → injects Bearer header → proxied to Notion +``` diff --git a/core/plugins/mcp-oauth/handlers/callback.go b/core/plugins/mcp-oauth/handlers/callback.go new file mode 100644 index 0000000..bd97861 --- /dev/null +++ b/core/plugins/mcp-oauth/handlers/callback.go @@ -0,0 +1,208 @@ +package custom + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "html" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/donbader/agent-sandbox/core/sdk/gateway" +) + +type oauthCallbackConfig struct { + TokenEndpoint string + ClientID string + ClientSecret string +} + +var ( + oauthCallbackProviders = map[string]oauthCallbackConfig{} + oauthCallbackTokenDir string + oauthCallbackHMACKey []byte +) + +func init() { + oauthCallbackTokenDir = "{{ .options.token_dir }}" + providersJSON := `{{ toJSON .options.providers }}` + + // Derive HMAC key from providers config (same derivation as middleware) + h := sha256.Sum256([]byte(providersJSON)) + oauthCallbackHMACKey = h[:] + + var providers map[string]map[string]any + if err := json.Unmarshal([]byte(providersJSON), &providers); err != nil { + slog.Error("oauth-callback: failed to parse providers", "error", err) + } else { + for name, cfg := range providers { + p := oauthCallbackConfig{} + if v, ok := cfg["token_endpoint"].(string); ok { + p.TokenEndpoint = v + } + if v, ok := cfg["client_id"].(string); ok { + p.ClientID = v + } + if v, ok := cfg["client_secret"].(string); ok { + p.ClientSecret = v + } + oauthCallbackProviders[name] = p + } + } + gateway.RegisterRoute(gateway.RouteDef{ + Path: "{{ .path }}", + Handler: handleOAuthCallback, + }) +} + +func handleOAuthCallback(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + if code == "" { + http.Error(w, "missing code parameter", http.StatusBadRequest) + return + } + if state == "" { + http.Error(w, "missing state parameter", http.StatusBadRequest) + return + } + // Validate HMAC state: format is "hmac_sig:provider_name" + parts := strings.SplitN(state, ":", 2) + if len(parts) != 2 { + http.Error(w, "invalid state format", http.StatusForbidden) + return + } + sig, providerName := parts[0], parts[1] + mac := hmac.New(sha256.New, oauthCallbackHMACKey) + mac.Write([]byte(providerName)) + expectedSig := hex.EncodeToString(mac.Sum(nil))[:16] + if !hmac.Equal([]byte(sig), []byte(expectedSig)) { + http.Error(w, "invalid state signature", http.StatusForbidden) + return + } + provider, ok := oauthCallbackProviders[providerName] + if !ok { + http.Error(w, "unknown provider", http.StatusBadRequest) + return + } + if provider.TokenEndpoint == "" { + http.Error(w, "provider not configured", http.StatusInternalServerError) + return + } + redirectURI := "{{ .public_url }}{{ .path }}" + token, err := exchangeCodeForToken(provider, code, redirectURI) + if err != nil { + slog.Error("oauth-callback: token exchange failed", "error", err) + http.Error(w, "token exchange failed", http.StatusInternalServerError) + return + } + tokenFile := oauthCallbackTokenDir + "/" + providerName + ".json" + if err := writeOAuthToken(tokenFile, token, provider); err != nil { + slog.Error("oauth-callback: failed to save token", "error", err) + http.Error(w, "failed to save token", http.StatusInternalServerError) + return + } + gateway.RegisterSecret(token.AccessToken) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = fmt.Fprintf(w, ` +

Authorization successful

+

Provider %s connected. You can close this tab.

+`, html.EscapeString(providerName)) +} + +type oauthTokenExchangeResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` +} + +func exchangeCodeForToken(provider oauthCallbackConfig, code, redirectURI string) (*oauthTokenExchangeResponse, error) { + u, err := url.Parse(provider.TokenEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid token_endpoint: %w", err) + } + if u.Scheme != "https" { + return nil, fmt.Errorf("token_endpoint must use https, got %q", u.Scheme) + } + params := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "client_id": {provider.ClientID}, + "redirect_uri": {redirectURI}, + } + if provider.ClientSecret != "" { + params.Set("client_secret", provider.ClientSecret) + } + client := &http.Client{Timeout: 30 * time.Second, Transport: cbSSRFSafe()} + resp, err := client.Post(provider.TokenEndpoint, "application/x-www-form-urlencoded", strings.NewReader(params.Encode())) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token endpoint returned %d", resp.StatusCode) + } + var tr oauthTokenExchangeResponse + if err := json.Unmarshal(body, &tr); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + return &tr, nil +} + +func writeOAuthToken(path string, token *oauthTokenExchangeResponse, _ oauthCallbackConfig) error { + expiresIn := token.ExpiresIn + if expiresIn == 0 { + expiresIn = 3600 + } + stored := map[string]any{ + "access_token": token.AccessToken, + "expires_at": time.Now().Unix() + expiresIn, + } + if token.RefreshToken != "" { + stored["refresh_token"] = token.RefreshToken + } + data, err := json.MarshalIndent(stored, "", " ") + if err != nil { + return err + } + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0600); err != nil { + return err + } + return os.Rename(tmp, path) +} + +func cbSSRFSafe() *http.Transport { + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("invalid address %q: %w", addr, err) + } + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("DNS lookup failed for %q: %w", host, err) + } + for _, ip := range ips { + if ip.IP.IsLoopback() || ip.IP.IsPrivate() || ip.IP.IsLinkLocalUnicast() || ip.IP.IsLinkLocalMulticast() { + return nil, fmt.Errorf("refusing to connect to private IP %s", ip.IP) + } + } + dialer := &net.Dialer{Timeout: 10 * time.Second} + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) + }, + } +} diff --git a/core/plugins/mcp-oauth/middlewares/oauth.go b/core/plugins/mcp-oauth/middlewares/oauth.go index 6b15299..505581e 100644 --- a/core/plugins/mcp-oauth/middlewares/oauth.go +++ b/core/plugins/mcp-oauth/middlewares/oauth.go @@ -2,6 +2,9 @@ package custom import ( "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -24,53 +27,215 @@ type storedToken struct { ExpiresAt int64 `json:"expires_at"` TokenEndpoint string `json:"token_endpoint"` ClientID string `json:"client_id"` - ClientSecret *string `json:"client_secret"` +} + +type oauthProviderInfo struct { + AuthorizeEndpoint string + TokenEndpoint string + ClientID string + ClientSecret string + Scopes string + Dynamic bool } type oauthState struct { - tokenFile string + tokenDir string + providers map[string]oauthProviderInfo mu sync.Mutex cachedToken *storedToken cachedUntil time.Time httpClient *http.Client } +var oauthHMACKey []byte + func init() { - tokenFile := "{{ .options.token_file }}" - domains := strings.Split("{{ .domainsList }}", ",") + providersJSON := `{{ toJSON .options.providers }}` + h := sha256.Sum256([]byte(providersJSON)) + oauthHMACKey = h[:] + + tokenDir := "{{ .options.token_dir }}" + callbackURL := "{{ .options.callback_url }}" + + providers := make(map[string]oauthProviderInfo) + var rawProviders map[string]map[string]any + if err := json.Unmarshal([]byte(providersJSON), &rawProviders); err != nil { + slog.Error("oauth: failed to parse providers config", "error", err) + } else { + for name, cfg := range rawProviders { + p := oauthProviderInfo{} + if v, ok := cfg["authorize_endpoint"].(string); ok { + p.AuthorizeEndpoint = v + } + if v, ok := cfg["token_endpoint"].(string); ok { + p.TokenEndpoint = v + } + if v, ok := cfg["client_id"].(string); ok { + p.ClientID = v + } + if v, ok := cfg["client_secret"].(string); ok { + p.ClientSecret = v + } + if v, ok := cfg["scopes"].(string); ok { + p.Scopes = v + } + // Dynamic mode: no client_id means we need discovery+registration + if p.ClientID == "" { + p.Dynamic = true + mcpURL, _ := cfg["mcp_url"].(string) + resolved := resolveProvider(mcpURL, callbackURL, tokenDir, name) + if resolved != nil { + p.AuthorizeEndpoint = resolved.AuthorizeEndpoint + p.TokenEndpoint = resolved.TokenEndpoint + p.ClientID = resolved.ClientID + p.ClientSecret = resolved.ClientSecret + } + } + providers[name] = p + } + } + // Register one middleware per provider, scoped to that provider's domain state := &oauthState{ - tokenFile: tokenFile, + tokenDir: tokenDir, + providers: providers, httpClient: &http.Client{ Timeout: 30 * time.Second, Transport: oauthSSRFSafeTransport(), }, } - if _, err := os.Stat(tokenFile); err != nil { - slog.Warn("oauth token file not found at startup", "path", tokenFile, "error", err) - } - for _, s := range oauthSecrets(tokenFile) { - gateway.RegisterSecret(s) - } + for name, p := range providers { + providerName := name + provider := p + providerTokenFile := tokenDir + "/" + providerName + ".json" - gateway.RegisterMiddleware(gateway.MiddlewareDef{ - Name: "oauth:" + domains[0], - Domains: domains, - Func: func(ctx *gateway.MiddlewareContext) error { - token, err := state.getValidToken() - if err != nil { - if errors.Is(err, os.ErrNotExist) { - slog.Debug("oauth: token file not found", "file", state.tokenFile) - } else { - slog.Error("oauth: failed to get token", "error", err) + // Extract domain from mcp_url for this provider + var providerDomain string + if raw, ok := rawProviders[providerName]; ok { + if mcpURL, ok := raw["mcp_url"].(string); ok { + if u, err := url.Parse(mcpURL); err == nil { + providerDomain = u.Hostname() + } + } + } + if providerDomain == "" { + continue + } + + // Register secrets for this provider's token + for _, s := range oauthSecrets(providerTokenFile) { + gateway.RegisterSecret(s) + } + + gateway.RegisterMiddleware(gateway.MiddlewareDef{ + Name: "oauth:" + providerName, + Domains: []string{providerDomain}, + Func: func(ctx *gateway.MiddlewareContext) error { + token, err := state.getValidToken(providerTokenFile) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + if provider.AuthorizeEndpoint != "" { + authorizeURL := mwBuildAuthorizeURL(provider, providerName, callbackURL) + ctx.SetAbortHeader("X-OAuth-Authorize-URL", authorizeURL) + ctx.SetAbortHeader("Content-Type", "application/json") + ctx.Abort(http.StatusUnauthorized, fmt.Sprintf( + `{"error":"oauth_required","provider":%q,"authorize_url":%q}`, + providerName, authorizeURL)) + return nil + } + slog.Debug("oauth: token file not found", "file", providerTokenFile) + } else { + slog.Error("oauth: failed to get token", "provider", providerName, "error", err) + } + return nil } + ctx.Request.Header.Set("Authorization", "Bearer "+token) return nil + }, + }) + } +} + +// resolveProvider does OAuth metadata discovery + dynamic client registration. +// Returns nil if discovery fails (provider will be skipped). +func resolveProvider(mcpURL, callbackURL, tokenDir, name string) *oauthProviderInfo { + if mcpURL == "" { + slog.Error("oauth: dynamic provider has no mcp_url", "provider", name) + return nil + } + + // Check for cached registration + regFile := tokenDir + "/" + name + ".reg.json" + if data, err := os.ReadFile(regFile); err == nil { + var cached struct { + AuthorizeEndpoint string `json:"authorize_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + } + if json.Unmarshal(data, &cached) == nil && cached.ClientID != "" { + slog.Info("oauth: using cached dynamic registration", "provider", name) + return &oauthProviderInfo{ + AuthorizeEndpoint: cached.AuthorizeEndpoint, + TokenEndpoint: cached.TokenEndpoint, + ClientID: cached.ClientID, + ClientSecret: cached.ClientSecret, } - ctx.Request.Header.Set("Authorization", "Bearer "+token) - return nil - }, - }) + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + meta, err := gateway.DiscoverOAuthMetadata(ctx, mcpURL) + if err != nil { + slog.Error("oauth: metadata discovery failed", "provider", name, "error", err) + return nil + } + + reg, err := gateway.RegisterOAuthClient(ctx, meta.RegistrationEndpoint, []string{callbackURL}, "agent-sandbox:"+name) + if err != nil { + slog.Error("oauth: dynamic registration failed", "provider", name, "error", err) + return nil + } + + // Persist registration for reuse across restarts + regData, _ := json.MarshalIndent(map[string]string{ + "authorize_endpoint": meta.AuthorizationEndpoint, + "token_endpoint": meta.TokenEndpoint, + "client_id": reg.ClientID, + "client_secret": reg.ClientSecret, + }, "", " ") + if err := os.WriteFile(regFile, regData, 0600); err != nil { + slog.Warn("oauth: failed to cache registration", "provider", name, "error", err) + } + + slog.Info("oauth: dynamic registration complete", "provider", name, "client_id", reg.ClientID) + return &oauthProviderInfo{ + AuthorizeEndpoint: meta.AuthorizationEndpoint, + TokenEndpoint: meta.TokenEndpoint, + ClientID: reg.ClientID, + ClientSecret: reg.ClientSecret, + } +} + +func mwBuildAuthorizeURL(provider oauthProviderInfo, providerName, callbackURL string) string { + mac := hmac.New(sha256.New, oauthHMACKey) + mac.Write([]byte(providerName)) + sig := hex.EncodeToString(mac.Sum(nil))[:16] + state := sig + ":" + providerName + + params := url.Values{ + "client_id": {provider.ClientID}, + "response_type": {"code"}, + "state": {state}, + "redirect_uri": {callbackURL}, + } + if provider.Scopes != "" { + params.Set("scope", provider.Scopes) + } + return provider.AuthorizeEndpoint + "?" + params.Encode() } func oauthSecrets(tokenFile string) []string { @@ -88,13 +253,13 @@ func oauthSecrets(tokenFile string) []string { return nil } -func (s *oauthState) getValidToken() (string, error) { +func (s *oauthState) getValidToken(tokenFile string) (string, error) { s.mu.Lock() defer s.mu.Unlock() if s.cachedToken != nil && time.Now().Before(s.cachedUntil) { return s.cachedToken.AccessToken, nil } - stored, err := s.readTokenFile() + stored, err := s.readTokenFile(tokenFile) if err != nil { return "", err } @@ -105,7 +270,7 @@ func (s *oauthState) getValidToken() (string, error) { return "", fmt.Errorf("token refresh failed: %w", err) } stored = refreshed - if err := s.writeTokenFile(stored); err != nil { + if err := s.writeTokenFile(tokenFile, stored); err != nil { slog.Error("oauth: failed to write refreshed token", "error", err) } } @@ -119,28 +284,28 @@ func (s *oauthState) getValidToken() (string, error) { return stored.AccessToken, nil } -func (s *oauthState) readTokenFile() (*storedToken, error) { - data, err := os.ReadFile(s.tokenFile) +func (s *oauthState) readTokenFile(tokenFile string) (*storedToken, error) { + data, err := os.ReadFile(tokenFile) if err != nil { - return nil, fmt.Errorf("reading token file %s: %w", s.tokenFile, err) + return nil, fmt.Errorf("reading token file %s: %w", tokenFile, err) } var token storedToken if err := json.Unmarshal(data, &token); err != nil { - return nil, fmt.Errorf("parsing token file %s: %w", s.tokenFile, err) + return nil, fmt.Errorf("parsing token file %s: %w", tokenFile, err) } return &token, nil } -func (s *oauthState) writeTokenFile(token *storedToken) error { +func (s *oauthState) writeTokenFile(tokenFile string, token *storedToken) error { data, err := json.MarshalIndent(token, "", " ") if err != nil { return err } - tmp := s.tokenFile + ".tmp" + tmp := tokenFile + ".tmp" if err := os.WriteFile(tmp, data, 0600); err != nil { return err } - return os.Rename(tmp, s.tokenFile) + return os.Rename(tmp, tokenFile) } type oauthTokenResponse struct { @@ -165,16 +330,13 @@ func (s *oauthState) refreshToken(stored *storedToken) (*storedToken, error) { "refresh_token": {*stored.RefreshToken}, "client_id": {stored.ClientID}, } - if stored.ClientSecret != nil && *stored.ClientSecret != "" { - params.Set("client_secret", *stored.ClientSecret) - } resp, err := s.httpClient.Post( stored.TokenEndpoint, "application/x-www-form-urlencoded", strings.NewReader(params.Encode()), ) if err != nil { - return nil, fmt.Errorf("refresh request to %s: %w", stored.TokenEndpoint, err) + return nil, fmt.Errorf("refresh request: %w", err) } defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) @@ -182,7 +344,7 @@ func (s *oauthState) refreshToken(stored *storedToken) (*storedToken, error) { return nil, fmt.Errorf("reading refresh response: %w", err) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("token refresh returned %d: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("token refresh returned %d", resp.StatusCode) } var tr oauthTokenResponse if err := json.Unmarshal(body, &tr); err != nil { @@ -202,7 +364,6 @@ func (s *oauthState) refreshToken(stored *storedToken) (*storedToken, error) { ExpiresAt: time.Now().Unix() + expiresIn, TokenEndpoint: stored.TokenEndpoint, ClientID: stored.ClientID, - ClientSecret: stored.ClientSecret, }, nil } @@ -219,7 +380,7 @@ func oauthSSRFSafeTransport() *http.Transport { } for _, ip := range ips { if ip.IP.IsLoopback() || ip.IP.IsPrivate() || ip.IP.IsLinkLocalUnicast() || ip.IP.IsLinkLocalMulticast() { - return nil, fmt.Errorf("oauth: refusing to connect to private IP %s (resolved from %s)", ip.IP, host) + return nil, fmt.Errorf("oauth: refusing to connect to private IP %s", ip.IP) } } dialer := &net.Dialer{Timeout: 10 * time.Second} diff --git a/core/plugins/mcp-oauth/plugin.yaml b/core/plugins/mcp-oauth/plugin.yaml index a7f8970..2c2d21e 100644 --- a/core/plugins/mcp-oauth/plugin.yaml +++ b/core/plugins/mcp-oauth/plugin.yaml @@ -3,21 +3,23 @@ options: providers: type: object required: true - description: "Map of provider name to MCP config" + description: "Map of provider name to MCP config (each needs at least mcp_url)" token_dir: type: string required: false default: "/data/oauth-tokens" description: "Directory for OAuth token files" - token_file: - type: string - required: false - default: "/data/oauth-tokens/token.json" - description: "Path to OAuth token JSON file" contributes: gateway: + services: +{{- range $name, $cfg := .plugin.options.providers }} + - url: "{{ index $cfg "mcp_url" }}" +{{- end }} volumes: - "oauth-tokens:{{ .plugin.options.token_dir }}" middlewares: - custom: "./middlewares/oauth.go" + routes: + - path: "/callback" + handler: "./handlers/callback.go" diff --git a/core/sdk/gateway/middleware.go b/core/sdk/gateway/middleware.go index 3fe59bc..cadb75f 100644 --- a/core/sdk/gateway/middleware.go +++ b/core/sdk/gateway/middleware.go @@ -3,12 +3,33 @@ package gateway import ( "net" "net/http" + "strings" ) // MiddlewareContext provides request access and environment resolution for custom middleware. type MiddlewareContext struct { Request *http.Request Env func(string) string + + // Abort fields: if AbortStatus is set (non-zero), the gateway returns this + // response instead of proxying the request to the upstream. + AbortStatus int + AbortHeaders http.Header + AbortBody string +} + +// Abort sets the context to return an HTTP response instead of proxying. +func (c *MiddlewareContext) Abort(status int, body string) { + c.AbortStatus = status + c.AbortBody = body +} + +// SetAbortHeader sets a header on the abort response. +func (c *MiddlewareContext) SetAbortHeader(key, value string) { + if c.AbortHeaders == nil { + c.AbortHeaders = make(http.Header) + } + c.AbortHeaders.Set(key, value) } // MiddlewareFunc is the signature for custom gateway middleware. @@ -80,4 +101,39 @@ func Secrets() []string { func ResetForTesting() { registry = nil secrets = nil + routeRegistry = nil +} + +// RouteHandlerFunc is the signature for custom gateway route handlers. +// Unlike middleware (which intercepts proxy requests), route handlers serve +// direct HTTP requests to the gateway on registered paths. +type RouteHandlerFunc func(w http.ResponseWriter, r *http.Request) + +// RouteDef defines a route handler with a path pattern. +type RouteDef struct { + // Path is the full namespaced path (e.g. /plugins/mcp-oauth/notion/callback). + Path string + Handler RouteHandlerFunc +} + +var routeRegistry []RouteDef + +// RegisterRoute registers a direct HTTP route handler on the gateway. +func RegisterRoute(def RouteDef) { + routeRegistry = append(routeRegistry, def) +} + +// Routes returns all registered route definitions. +func Routes() []RouteDef { + return routeRegistry +} + +// MatchRoute returns the handler for the given request path, or nil if no route matches. +func MatchRoute(path string) RouteHandlerFunc { + for _, route := range routeRegistry { + if path == route.Path || strings.HasPrefix(path, route.Path+"/") { + return route.Handler + } + } + return nil } diff --git a/core/sdk/gateway/middleware_test.go b/core/sdk/gateway/middleware_test.go index 36368a8..63db982 100644 --- a/core/sdk/gateway/middleware_test.go +++ b/core/sdk/gateway/middleware_test.go @@ -1,6 +1,9 @@ package gateway -import "testing" +import ( + "net/http" + "testing" +) func TestRegisterSecret(t *testing.T) { // Reset state @@ -21,3 +24,75 @@ func TestRegisterSecret(t *testing.T) { t.Errorf("expected 'another-secret', got %q", got[1]) } } + +func TestRegisterRoute(t *testing.T) { + ResetForTesting() + + handler := func(w http.ResponseWriter, r *http.Request) {} + RegisterRoute(RouteDef{Path: "/plugins/mcp-oauth/callback", Handler: handler}) + RegisterRoute(RouteDef{Path: "/plugins/other/webhook", Handler: handler}) + + routes := Routes() + if len(routes) != 2 { + t.Fatalf("expected 2 routes, got %d", len(routes)) + } + if routes[0].Path != "/plugins/mcp-oauth/callback" { + t.Errorf("expected '/plugins/mcp-oauth/callback', got %q", routes[0].Path) + } + if routes[1].Path != "/plugins/other/webhook" { + t.Errorf("expected '/plugins/other/webhook', got %q", routes[1].Path) + } +} + +func TestMatchRoute(t *testing.T) { + ResetForTesting() + + RegisterRoute(RouteDef{ + Path: "/plugins/mcp-oauth/callback", + Handler: func(w http.ResponseWriter, r *http.Request) {}, + }) + + // Exact match + h := MatchRoute("/plugins/mcp-oauth/callback") + if h == nil { + t.Fatal("expected handler for exact path match") + } + + // Prefix match (sub-path) + h = MatchRoute("/plugins/mcp-oauth/callback/extra") + if h == nil { + t.Fatal("expected handler for prefix path match") + } + + // No match + h = MatchRoute("/plugins/other/callback") + if h != nil { + t.Fatal("expected nil for non-matching path") + } + + // No match (partial prefix without /) + h = MatchRoute("/plugins/mcp-oauth/callbackextra") + if h != nil { + t.Fatal("expected nil for partial prefix without /") + } +} + +func TestMiddlewareContext_Abort(t *testing.T) { + ctx := &MiddlewareContext{ + Request: &http.Request{}, + Env: func(s string) string { return "" }, + } + + ctx.SetAbortHeader("X-Test", "value") + ctx.Abort(401, `{"error":"unauthorized"}`) + + if ctx.AbortStatus != 401 { + t.Errorf("expected status 401, got %d", ctx.AbortStatus) + } + if ctx.AbortBody != `{"error":"unauthorized"}` { + t.Errorf("unexpected body: %s", ctx.AbortBody) + } + if ctx.AbortHeaders.Get("X-Test") != "value" { + t.Errorf("expected X-Test header, got %q", ctx.AbortHeaders.Get("X-Test")) + } +} diff --git a/core/sdk/gateway/oauth_discovery.go b/core/sdk/gateway/oauth_discovery.go new file mode 100644 index 0000000..91b826a --- /dev/null +++ b/core/sdk/gateway/oauth_discovery.go @@ -0,0 +1,133 @@ +package gateway + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// OAuthMetadata holds discovered OAuth server metadata (RFC 8414). +type OAuthMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint"` + ScopesSupported []string `json:"scopes_supported"` +} + +// OAuthRegistration holds dynamic client registration response (RFC 7591). +type OAuthRegistration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` + RedirectURIs []string `json:"redirect_uris,omitempty"` +} + +// DiscoverOAuthMetadata fetches OAuth server metadata from the MCP server. +// Tries .well-known/oauth-authorization-server first, falls back to .well-known/openid-configuration. +func DiscoverOAuthMetadata(ctx context.Context, mcpURL string) (*OAuthMetadata, error) { + base, err := url.Parse(mcpURL) + if err != nil { + return nil, fmt.Errorf("parse mcp_url: %w", err) + } + + // Try RFC 8414 path first + wellKnownPaths := []string{ + "/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", + } + + client := &http.Client{Timeout: 15 * time.Second} + + for _, path := range wellKnownPaths { + metaURL := base.Scheme + "://" + base.Host + path + req, err := http.NewRequestWithContext(ctx, "GET", metaURL, nil) + if err != nil { + continue + } + resp, err := client.Do(req) + if err != nil { + continue + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + continue + } + + var meta OAuthMetadata + if err := json.Unmarshal(body, &meta); err != nil { + continue + } + if meta.AuthorizationEndpoint != "" && meta.TokenEndpoint != "" { + return &meta, nil + } + } + + return nil, fmt.Errorf("oauth metadata not found at %s", base.Host) +} + +// RegisterOAuthClient performs Dynamic Client Registration (RFC 7591). +func RegisterOAuthClient(ctx context.Context, registrationEndpoint string, redirectURIs []string, clientName string) (*OAuthRegistration, error) { + if registrationEndpoint == "" { + return nil, fmt.Errorf("no registration_endpoint available") + } + + u, err := url.Parse(registrationEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid registration_endpoint: %w", err) + } + if u.Scheme != "https" { + return nil, fmt.Errorf("registration_endpoint must use https, got %q", u.Scheme) + } + + reqBody := map[string]any{ + "client_name": clientName, + "redirect_uris": redirectURIs, + "grant_types": []string{"authorization_code", "refresh_token"}, + "response_types": []string{"code"}, + "token_endpoint_auth_method": "client_secret_post", + } + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal registration request: %w", err) + } + + client := &http.Client{Timeout: 15 * time.Second} + req, err := http.NewRequestWithContext(ctx, "POST", registrationEndpoint, strings.NewReader(string(bodyBytes))) + if err != nil { + return nil, fmt.Errorf("create registration request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("registration request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("reading registration response: %w", err) + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("registration returned %d: %s", resp.StatusCode, string(body)) + } + + var reg OAuthRegistration + if err := json.Unmarshal(body, ®); err != nil { + return nil, fmt.Errorf("parse registration response: %w", err) + } + if reg.ClientID == "" { + return nil, fmt.Errorf("registration response missing client_id") + } + + return ®, nil +} diff --git a/docs/configuration.md b/docs/configuration.md index dd5945f..9a58d07 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -50,6 +50,7 @@ runtime: - "./local:/home/agent/local" gateway: + public_url: string # optional — public URL for OAuth callbacks and webhook receivers services: # optional — external services proxied through the gateway - url: https://api.example.com network: string # optional — compose network to attach diff --git a/docs/guides/creating-plugins.md b/docs/guides/creating-plugins.md index 8389c6f..0efac82 100644 --- a/docs/guides/creating-plugins.md +++ b/docs/guides/creating-plugins.md @@ -100,10 +100,50 @@ contributes: | `runtime.volumes` | Volume mount specs for docker-compose | | `gateway.services` | Services the gateway intercepts. Each entry has a `url` and a list of `middlewares` | | `gateway.services[].middlewares[].custom` | Path to a Go middleware file, relative to the plugin directory | +| `gateway.routes` | HTTP route handlers registered on the gateway (see below) | | `sidecar.services` | Additional Docker containers that run alongside the agent (see below) | All `contributes` fields support Go `text/template` syntax. See **Template Context** below for available variables and functions. +### `contributes.gateway.routes` + +Registers HTTP route handlers directly on the gateway. Routes are automatically namespaced under `/plugins/{plugin-name}/` to prevent conflicts between plugins. + +```yaml +contributes: + gateway: + routes: + - path: "/callback" + handler: "./handlers/callback.go" +``` + +| Field | Description | +|-------|-------------| +| `path` | Relative path (prefixed automatically with `/plugins/{plugin-name}`) | +| `handler` | Path to a Go handler file, relative to the plugin directory | + +The handler file is a Go template (same as middleware) with access to `{{ .path }}` (the full namespaced path) and `{{ .options }}`. It must self-register using `gateway.RegisterRoute()` in an `init()` function: + +```go +package custom + +import ( + "net/http" + "github.com/donbader/agent-sandbox/core/sdk/gateway" +) + +func init() { + gateway.RegisterRoute(gateway.RouteDef{ + Path: "{{ .path }}", + Handler: myHandler, + }) +} + +func myHandler(w http.ResponseWriter, r *http.Request) { + // Handle the request +} +``` + ### `contributes.sidecar.services` Defines additional Docker containers that run alongside the agent. Each sidecar is a separate compose service. Sidecars automatically get `depends_on: { agent: { condition: service_healthy } }` so they start only after the agent is healthy. diff --git a/docs/plugins.md b/docs/plugins.md index 4cbc010..7c35ba9 100644 --- a/docs/plugins.md +++ b/docs/plugins.md @@ -93,24 +93,36 @@ Key-only auth. No passwords. Connect with `ssh -p 2222 agent@localhost`. ### @builtin/mcp-oauth -OAuth token storage for MCP (Model Context Protocol) providers via a shared volume. +Full OAuth lifecycle for MCP providers: token injection, automatic refresh, and browser-based authorization via gateway callback. ```yaml +gateway: + public_url: "https://gateway.myagent.example.com" + services: + - url: https://mcp.notion.com + installations: - plugin: "@builtin/mcp-oauth" options: providers: notion: mcp_url: https://mcp.notion.com/mcp + authorize_endpoint: https://api.notion.com/v1/oauth/authorize + token_endpoint: https://api.notion.com/v1/oauth/token + client_id: "your-client-id" + client_secret: "your-client-secret" + scopes: "read_content" token_dir: "/data/oauth-tokens" ``` | Option | Type | Required | Default | Description | |--------|------|----------|---------|-------------| -| `providers` | object | yes | — | Map of provider name to MCP config (`mcp_url` required per provider) | +| `providers` | object | yes | — | Map of provider name to OAuth config | | `token_dir` | string | no | `/data/oauth-tokens` | Directory for OAuth token files | -You must also declare the provider endpoints as gateway services in your `agent.yaml`. +**Flow:** When no token exists, the middleware returns 401 with an `authorize_url`. The user visits the URL, authorizes, and the OAuth provider redirects to `/plugins/mcp-oauth/callback` on the gateway. The callback handler exchanges the code for tokens and writes them to the shared volume. Subsequent requests are automatically authenticated. + +**Contributes:** gateway middleware (token injection + 401), gateway route (`/plugins/mcp-oauth/callback`), shared volume. --- diff --git a/examples/local-coding/README.md b/examples/local-coding/README.md index 2d90e7a..a5fd204 100644 --- a/examples/local-coding/README.md +++ b/examples/local-coding/README.md @@ -89,3 +89,31 @@ With `volume: true`, the home directory persists across container restarts (shel | Variable | Description | |----------|-------------| | `STX_LLM_GATEWAY_API_KEY` | API key for the LLM gateway | + +## Notion MCP Integration + +This example includes the `mcp-oauth` plugin configured for Notion. On first use: + +1. Start the sandbox: + ```bash + agent-sandbox generate + agent-sandbox compose up --build -d + ``` + +2. When the agent tries to access Notion, the gateway returns a 401 with an authorize URL. The URL is also available at: + ``` + http://localhost:8080/plugins/mcp-oauth/callback + ``` + +3. The first request to Notion's MCP will trigger OAuth discovery and dynamic client registration automatically. The gateway logs the authorize URL: + ```bash + agent-sandbox compose logs coder-gateway | grep authorize_url + ``` + +4. Open the authorize URL in your browser, log in to Notion, and authorize access. + +5. Notion redirects to `http://localhost:8080/plugins/mcp-oauth/callback` — the gateway exchanges the code for tokens and stores them automatically. + +6. Subsequent requests to Notion are authenticated transparently. + +> **Note:** No Notion developer app or API key needed. The plugin uses OAuth Dynamic Client Registration (RFC 7591) — credentials are obtained automatically. diff --git a/examples/local-coding/agent.yaml b/examples/local-coding/agent.yaml index 5ea20c9..9ce4d74 100644 --- a/examples/local-coding/agent.yaml +++ b/examples/local-coding/agent.yaml @@ -15,3 +15,8 @@ installations: options: home_directory: "./home" volume: true + - plugin: "@builtin/mcp-oauth" + options: + providers: + notion: + mcp_url: https://mcp.notion.com/mcp diff --git a/internal/config/config.go b/internal/config/config.go index b87c1fc..013204d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -67,7 +67,8 @@ type RuntimeConfig struct { // GatewayConfig holds gateway proxy configuration. type GatewayConfig struct { - Services []GatewayServiceEntry `yaml:"services" json:"services,omitempty" jsonschema:"title=services,description=External services proxied through the gateway"` + PublicURL string `yaml:"public_url" json:"public_url,omitempty" jsonschema:"title=public_url,description=Public URL of the gateway (used for OAuth callbacks and webhook receivers)"` + Services []GatewayServiceEntry `yaml:"services" json:"services,omitempty" jsonschema:"title=services,description=External services proxied through the gateway"` } // GatewayServiceEntry represents an allowed upstream service. diff --git a/internal/generate/v1/compose.go b/internal/generate/v1/compose.go index 30132d6..329d73a 100644 --- a/internal/generate/v1/compose.go +++ b/internal/generate/v1/compose.go @@ -110,6 +110,10 @@ func BuildCompose(cfg *config.Config, contribs *plugin.Contributions, projectDir "retries": 3, }, } + // Expose gateway HTTP port when plugin routes are registered (e.g. OAuth callbacks) + if contribs != nil && len(contribs.Gateway.Routes) > 0 { + gatewaySvc["ports"] = []string{"8080:8080"} + } if len(gatewayEnv) > 0 { gatewaySvc["environment"] = gatewayEnv } diff --git a/internal/generate/v1/gateway_config.go b/internal/generate/v1/gateway_config.go index 854583b..d296413 100644 --- a/internal/generate/v1/gateway_config.go +++ b/internal/generate/v1/gateway_config.go @@ -19,6 +19,8 @@ type GatewayConfigOutput struct { Services []GatewayServiceOutput Middlewares []MiddlewareRef // custom .go files to copy with domain scope AuthHeaders []AuthHeaderEntry // auth-header entries to generate as .go files + Routes []RouteRef // route handler .go files with namespaced paths + PublicURL string // gateway public URL for callbacks } // MiddlewareRef associates a custom middleware file with its target domains. @@ -27,6 +29,13 @@ type MiddlewareRef struct { Domains []string // domains this middleware applies to } +// RouteRef associates a route handler file with its namespaced path. +type RouteRef struct { + Path string // namespaced URL path (e.g. /plugins/mcp-oauth/callback) + Handler string // path to handler .go file + PluginName string // plugin that contributed this route +} + // AuthHeaderEntry describes an auth-header middleware to generate at build time. type AuthHeaderEntry struct { Domain string @@ -48,11 +57,20 @@ type gatewayRuntimeConfig struct { DNSListen string `yaml:"dns_listen"` MITMDomains []string `yaml:"mitm_domains"` HealthAddr string `yaml:"health_addr,omitempty"` + PublicURL string `yaml:"public_url,omitempty"` } // BuildGatewayConfig merges user gateway config with plugin contributions. func BuildGatewayConfig(cfg *config.Config, contribs *plugin.Contributions) *GatewayConfigOutput { - out := &GatewayConfigOutput{} + publicURL := cfg.Gateway.PublicURL + // Default to localhost:8080 when no public_url configured (local dev) + if publicURL == "" { + publicURL = "http://localhost:8080" + } + + out := &GatewayConfigOutput{ + PublicURL: publicURL, + } // User-declared services for _, svc := range cfg.Gateway.Services { @@ -120,6 +138,7 @@ func WriteGatewayRuntimeConfig(buildDir string, gwCfg *GatewayConfigOutput) erro rc := gatewayRuntimeConfig{ Listen: ":8443", DNSListen: ":53", + PublicURL: gwCfg.PublicURL, } for _, svc := range gwCfg.Services { diff --git a/internal/generate/v1/gateway_config_test.go b/internal/generate/v1/gateway_config_test.go index fb8bc60..0066177 100644 --- a/internal/generate/v1/gateway_config_test.go +++ b/internal/generate/v1/gateway_config_test.go @@ -72,3 +72,15 @@ func TestBuildGatewayConfig_NilContribs(t *testing.T) { gwCfg := BuildGatewayConfig(cfg, nil) require.Len(t, gwCfg.Services, 1) } + +func TestBuildGatewayConfig_PublicURL(t *testing.T) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + PublicURL: "https://gateway.example.com", + Services: []config.GatewayServiceEntry{{URL: "https://api.example.com"}}, + }, + } + + gwCfg := BuildGatewayConfig(cfg, nil) + assert.Equal(t, "https://gateway.example.com", gwCfg.PublicURL) +} diff --git a/internal/generate/v1/generator.go b/internal/generate/v1/generator.go index b95f214..7165033 100644 --- a/internal/generate/v1/generator.go +++ b/internal/generate/v1/generator.go @@ -288,11 +288,23 @@ func (g *Generator) generateAgent(cfg *config.Config, agentDir, buildDir string) } gwCfg := BuildGatewayConfig(cfg, merged) + + // Collect routes from each plugin with namespace prefixing + routeRefs, err := g.collectPluginRoutes(resolved, buildDir) + if err != nil { + return nil, err + } + gwCfg.Routes = routeRefs + if err := WriteGatewayRuntimeConfig(buildDir, gwCfg); err != nil { return nil, fmt.Errorf("write gateway runtime config: %w", err) } if len(gwCfg.Middlewares) > 0 { allOpts := collectAllOptions(cfg) + // Inject computed callback URLs for plugins with routes + for _, route := range gwCfg.Routes { + allOpts["callback_url"] = gwCfg.PublicURL + route.Path + } if err := CopyCustomMiddleware(g.projectDir, buildDir, gwCfg.Middlewares, allOpts); err != nil { return nil, fmt.Errorf("copy middleware: %w", err) } @@ -302,6 +314,12 @@ func (g *Generator) generateAgent(cfg *config.Config, agentDir, buildDir string) return nil, fmt.Errorf("generate auth-header middleware: %w", err) } } + if len(gwCfg.Routes) > 0 { + allOpts := collectAllOptions(cfg) + if err := CopyRouteHandlers(g.projectDir, buildDir, gwCfg.Routes, allOpts, gwCfg.PublicURL); err != nil { + return nil, fmt.Errorf("copy route handlers: %w", err) + } + } return &AgentResult{Config: cfg, Contribs: merged, BuildDir: buildDir}, nil } diff --git a/internal/generate/v1/middleware_copy.go b/internal/generate/v1/middleware_copy.go index 4980fda..cc1c047 100644 --- a/internal/generate/v1/middleware_copy.go +++ b/internal/generate/v1/middleware_copy.go @@ -3,6 +3,7 @@ package v1 import ( "bytes" "encoding/base64" + "encoding/json" "fmt" "os" "path/filepath" @@ -134,7 +135,17 @@ func renderMiddleware(name, content string, data map[string]any) (string, error) return content, nil } - tmpl, err := template.New(name).Parse(content) + funcMap := template.FuncMap{ + "toJSON": func(v any) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("toJSON: %w", err) + } + return string(b), nil + }, + } + + tmpl, err := template.New(name).Funcs(funcMap).Parse(content) if err != nil { return "", fmt.Errorf("parse template: %w", err) } diff --git a/internal/generate/v1/routes.go b/internal/generate/v1/routes.go new file mode 100644 index 0000000..632e2a4 --- /dev/null +++ b/internal/generate/v1/routes.go @@ -0,0 +1,153 @@ +package v1 + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "text/template" +) + +// collectPluginRoutes iterates over resolved plugins, namespaces their declared routes, +// checks for conflicts, and extracts/resolves handler file paths. +// Returns route refs ready for code generation. +func (g *Generator) collectPluginRoutes(resolved map[string]*resolvedPlugin, buildDir string) ([]RouteRef, error) { + var routes []RouteRef + seen := make(map[string]string) // full path → plugin name (for conflict detection) + + for ref, rp := range resolved { + for _, route := range rp.rendered.Gateway.Routes { + if route.Path == "" || route.Handler == "" { + return nil, fmt.Errorf("plugin %q: route entry must have both path and handler", ref) + } + + // Namespace: /plugins/{plugin-name}{declared-path} + namespacedPath := "/plugins/" + rp.def.Name + normalizePath(route.Path) + + // Conflict detection + if owner, exists := seen[namespacedPath]; exists { + return nil, fmt.Errorf("route conflict: path %q registered by both %q and %q", namespacedPath, owner, ref) + } + seen[namespacedPath] = ref + + // Resolve handler file path + handlerPath, err := g.resolveRouteHandler(rp.def.Name, route.Handler, rp.def.BaseDir, buildDir) + if err != nil { + return nil, fmt.Errorf("plugin %q route %q: %w", ref, route.Path, err) + } + + routes = append(routes, RouteRef{ + Path: namespacedPath, + Handler: handlerPath, + PluginName: rp.def.Name, + }) + } + } + + return routes, nil +} + +// resolveRouteHandler resolves a route handler file path. +// For local plugins (baseDir != ""), it's relative to the plugin directory. +// For bundled plugins, it's extracted from the bundled FS. +func (g *Generator) resolveRouteHandler(pluginName, handler, baseDir, buildDir string) (string, error) { + if baseDir != "" { + return filepath.Join(baseDir, handler), nil + } + + // Bundled plugin — extract handler from FS + if g.bundledFS == nil { + return "", fmt.Errorf("no bundled FS available to extract handler %q", handler) + } + return g.extractBundledMiddleware(pluginName, handler, buildDir) +} + +// CopyRouteHandlers copies route handler .go files into the gateway build context. +// Each handler is a Go template rendered with path and options, and self-registers +// via init() calling gateway.RegisterRoute() — same pattern as custom middleware. +func CopyRouteHandlers(projectDir, outDir string, routes []RouteRef, opts map[string]any, publicURL string) error { + if len(routes) == 0 { + return nil + } + + destDir := filepath.Join(outDir, "gateway-src", "core", "gateway", "middlewares", "custom") + if err := os.MkdirAll(destDir, 0755); err != nil { + return fmt.Errorf("create route handler dest dir: %w", err) + } + + // Resolve ${VAR} references in options to actual env values + resolved := resolveEnvVars(opts) + + for _, route := range routes { + var srcPath string + if filepath.IsAbs(route.Handler) { + srcPath = route.Handler + } else { + srcPath = filepath.Join(projectDir, route.Handler) + } + content, err := os.ReadFile(srcPath) + if err != nil { + return fmt.Errorf("read route handler %s: %w", route.Handler, err) + } + + // Template data includes options, the namespaced path, and public_url + data := map[string]any{ + "options": resolved, + "path": route.Path, + "public_url": publicURL, + } + + // Template-render the handler file + rendered, err := renderRouteHandler(srcPath, string(content), data) + if err != nil { + return fmt.Errorf("render route handler %s: %w", route.Handler, err) + } + + filename := fmt.Sprintf("route_%s_%s.go", sanitizeFilename(route.PluginName), sanitizeFilename(filepath.Base(route.Handler))) + destFile := filepath.Join(destDir, filename) + if err := os.WriteFile(destFile, []byte(rendered), 0644); err != nil { + return fmt.Errorf("write route handler %s: %w", destFile, err) + } + } + + return nil +} + +// renderRouteHandler executes Go templates in route handler source code. +func renderRouteHandler(name, content string, data map[string]any) (string, error) { + if !strings.Contains(content, "{{") { + return content, nil + } + + funcMap := template.FuncMap{ + "toJSON": func(v any) (string, error) { + b, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("toJSON: %w", err) + } + return string(b), nil + }, + } + + tmpl, err := template.New(name).Funcs(funcMap).Parse(content) + if err != nil { + return "", fmt.Errorf("parse template: %w", err) + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return "", fmt.Errorf("execute template: %w", err) + } + return buf.String(), nil +} + +// normalizePath ensures a path starts with / and has no trailing slash. +func normalizePath(path string) string { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + return strings.TrimRight(path, "/") +} + diff --git a/internal/generate/v1/routes_test.go b/internal/generate/v1/routes_test.go new file mode 100644 index 0000000..55f26c6 --- /dev/null +++ b/internal/generate/v1/routes_test.go @@ -0,0 +1,178 @@ +package v1 + +import ( + "os" + "path/filepath" + "testing" + "testing/fstest" + + "github.com/donbader/agent-sandbox/internal/plugin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCollectPluginRoutes_Namespacing(t *testing.T) { + tmpDir := t.TempDir() + handlerFile := filepath.Join(tmpDir, "handler.go") + require.NoError(t, os.WriteFile(handlerFile, []byte("package custom"), 0644)) + + g := &Generator{projectDir: tmpDir} + resolved := map[string]*resolvedPlugin{ + "@builtin/mcp-oauth": { + def: &plugin.PluginDef{Name: "mcp-oauth", BaseDir: tmpDir}, + rendered: &plugin.Contributions{ + Gateway: plugin.GatewayContrib{ + Routes: []plugin.RouteEntry{ + {Path: "/callback", Handler: "handler.go"}, + }, + }, + }, + }, + } + + routes, err := g.collectPluginRoutes(resolved, tmpDir) + require.NoError(t, err) + require.Len(t, routes, 1) + assert.Equal(t, "/plugins/mcp-oauth/callback", routes[0].Path) + assert.Equal(t, "mcp-oauth", routes[0].PluginName) +} + +func TestCollectPluginRoutes_ConflictDetection(t *testing.T) { + tmpDir := t.TempDir() + handlerFile := filepath.Join(tmpDir, "handler.go") + require.NoError(t, os.WriteFile(handlerFile, []byte("package custom"), 0644)) + + g := &Generator{projectDir: tmpDir} + // Two plugins with the same name would produce the same namespace path + // In practice this can't happen (plugin names are unique), but test the mechanism + resolved := map[string]*resolvedPlugin{ + "@builtin/my-plugin": { + def: &plugin.PluginDef{Name: "my-plugin", BaseDir: tmpDir}, + rendered: &plugin.Contributions{ + Gateway: plugin.GatewayContrib{ + Routes: []plugin.RouteEntry{ + {Path: "/webhook", Handler: "handler.go"}, + {Path: "/webhook", Handler: "handler.go"}, // duplicate + }, + }, + }, + }, + } + + _, err := g.collectPluginRoutes(resolved, tmpDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "route conflict") +} + +func TestCollectPluginRoutes_MissingPath(t *testing.T) { + tmpDir := t.TempDir() + g := &Generator{projectDir: tmpDir} + resolved := map[string]*resolvedPlugin{ + "@builtin/bad": { + def: &plugin.PluginDef{Name: "bad", BaseDir: tmpDir}, + rendered: &plugin.Contributions{ + Gateway: plugin.GatewayContrib{ + Routes: []plugin.RouteEntry{ + {Path: "", Handler: "handler.go"}, + }, + }, + }, + }, + } + + _, err := g.collectPluginRoutes(resolved, tmpDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "must have both path and handler") +} + +func TestCollectPluginRoutes_BundledFS(t *testing.T) { + tmpDir := t.TempDir() + buildDir := filepath.Join(tmpDir, ".build") + require.NoError(t, os.MkdirAll(buildDir, 0755)) + + bundledFS := fstest.MapFS{ + "my-plugin/handlers/callback.go": &fstest.MapFile{Data: []byte("package custom")}, + } + + g := &Generator{projectDir: tmpDir, bundledFS: bundledFS} + resolved := map[string]*resolvedPlugin{ + "@builtin/my-plugin": { + def: &plugin.PluginDef{Name: "my-plugin"}, // no BaseDir = bundled + rendered: &plugin.Contributions{ + Gateway: plugin.GatewayContrib{ + Routes: []plugin.RouteEntry{ + {Path: "/callback", Handler: "./handlers/callback.go"}, + }, + }, + }, + }, + } + + routes, err := g.collectPluginRoutes(resolved, buildDir) + require.NoError(t, err) + require.Len(t, routes, 1) + assert.Equal(t, "/plugins/my-plugin/callback", routes[0].Path) + + // Verify handler was extracted + _, err = os.Stat(routes[0].Handler) + assert.NoError(t, err) +} + +func TestCopyRouteHandlers(t *testing.T) { + tmpDir := t.TempDir() + outDir := filepath.Join(tmpDir, ".build") + require.NoError(t, os.MkdirAll(outDir, 0755)) + + // Create a handler file with a template variable + handlerContent := `package custom + +import "github.com/donbader/agent-sandbox/core/sdk/gateway" + +func init() { + gateway.RegisterRoute(gateway.RouteDef{ + Path: "{{ .path }}", + }) +} +` + handlerFile := filepath.Join(tmpDir, "handler.go") + require.NoError(t, os.WriteFile(handlerFile, []byte(handlerContent), 0644)) + + routes := []RouteRef{ + { + Path: "/plugins/mcp-oauth/callback", + Handler: handlerFile, + PluginName: "mcp-oauth", + }, + } + + err := CopyRouteHandlers(tmpDir, outDir, routes, map[string]any{}, "https://gateway.example.com") + require.NoError(t, err) + + // Verify the rendered handler exists + destDir := filepath.Join(outDir, "gateway-src", "core", "gateway", "middlewares", "custom") + entries, err := os.ReadDir(destDir) + require.NoError(t, err) + require.Len(t, entries, 1) + + // Verify template was rendered + content, err := os.ReadFile(filepath.Join(destDir, entries[0].Name())) + require.NoError(t, err) + assert.Contains(t, string(content), `/plugins/mcp-oauth/callback`) + assert.NotContains(t, string(content), "{{ .path }}") +} + +func TestNormalizePath(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"/callback", "/callback"}, + {"callback", "/callback"}, + {"/callback/", "/callback"}, + {"/foo/bar/", "/foo/bar"}, + } + for _, tt := range tests { + got := normalizePath(tt.input) + assert.Equal(t, tt.want, got, "normalizePath(%q)", tt.input) + } +} diff --git a/internal/plugin/merge.go b/internal/plugin/merge.go index c5c4c01..7a9b574 100644 --- a/internal/plugin/merge.go +++ b/internal/plugin/merge.go @@ -16,6 +16,7 @@ func MergeContributions(contribs ...*Contributions) *Contributions { merged.Runtime.Volumes = append(merged.Runtime.Volumes, c.Runtime.Volumes...) merged.Gateway.Services = append(merged.Gateway.Services, c.Gateway.Services...) merged.Gateway.Volumes = append(merged.Gateway.Volumes, c.Gateway.Volumes...) + merged.Gateway.Routes = append(merged.Gateway.Routes, c.Gateway.Routes...) for name, svc := range c.Sidecar.Services { merged.Sidecar.Services[name] = svc } diff --git a/internal/plugin/types.go b/internal/plugin/types.go index da30622..ca6453c 100644 --- a/internal/plugin/types.go +++ b/internal/plugin/types.go @@ -42,6 +42,14 @@ type RuntimeContrib struct { type GatewayContrib struct { Services []GatewayService `yaml:"services"` Volumes []string `yaml:"volumes"` + Routes []RouteEntry `yaml:"routes"` +} + +// RouteEntry declares an HTTP route handler contributed by a plugin. +// The path is relative to the plugin's namespace (/plugins/{plugin-name}/...). +type RouteEntry struct { + Path string `yaml:"path"` // relative path (e.g. "/callback") + Handler string `yaml:"handler"` // path to handler .go file } type GatewayService struct {