-
Notifications
You must be signed in to change notification settings - Fork 1
feat: plugin gateway route handlers with auto-namespacing #126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f711c15
63840b9
3cbbeda
3bbdae5
16fb635
acd5ac0
96c2e1a
699dcdd
07b53af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,45 +1,84 @@ | ||
| # mcp-oauth | ||
|
|
||
| Provides OAuth token storage for MCP (Model Context Protocol) providers via a shared volume between the gateway and agent. | ||
| Provides full OAuth lifecycle for MCP (Model Context Protocol) providers: automatic token injection, refresh, and browser-based authorization via gateway callback. | ||
|
|
||
| ## How It Works | ||
|
|
||
| Declares a named volume (`oauth-tokens`) mounted into both the gateway and agent containers at a configurable path. MCP providers that require OAuth can store and read token files from this shared location. | ||
|
|
||
| > **Note:** This plugin currently only handles the volume declaration. You must manually declare gateway services for each MCP provider endpoint in your `agent.yaml`. Full automation (dynamic service entries from provider URLs) is planned. | ||
| 1. **Middleware** intercepts requests to configured domains. If a valid token exists, injects `Authorization: Bearer <token>`. If no token exists, returns 401 with an `authorize_url` for the user to click. | ||
| 2. **Callback handler** at `/plugins/mcp-oauth/callback` receives the OAuth authorization code, exchanges it for tokens, and writes the token file to the shared volume. | ||
| 3. **Shared volume** (`oauth-tokens`) is mounted into both gateway and agent containers so the MCP client can read tokens written by the gateway. | ||
|
|
||
| ## Usage | ||
|
|
||
| ```yaml | ||
| # agent.yaml | ||
| gateway: | ||
| public_url: "https://gateway.myagent.example.com" | ||
| services: | ||
| - url: https://mcp.notion.com | ||
|
|
||
| installations: | ||
| - plugin: "@builtin/mcp-oauth" | ||
| options: | ||
| providers: | ||
| # Dynamic mode: just provide mcp_url — credentials auto-discovered | ||
| notion: | ||
| mcp_url: https://mcp.notion.com/mcp | ||
| token_dir: "/data/oauth-tokens" | ||
|
|
||
| gateway: | ||
| services: | ||
| - url: https://mcp.notion.com | ||
| headers: | ||
| Authorization: Bearer ${NOTION_TOKEN} | ||
| ``` | ||
|
|
||
| ```bash | ||
| # .env | ||
| NOTION_TOKEN=ntn_xxxx | ||
| # Static mode: provide all OAuth details manually | ||
| custom-provider: | ||
| mcp_url: https://custom.example.com/mcp | ||
| authorize_endpoint: https://custom.example.com/oauth/authorize | ||
| token_endpoint: https://custom.example.com/oauth/token | ||
| client_id: "your-client-id" | ||
| client_secret: "your-client-secret" | ||
| scopes: "read_content" | ||
| token_dir: "/data/oauth-tokens" | ||
| ``` | ||
|
|
||
| ## Options | ||
|
|
||
| | Option | Type | Required | Default | Description | | ||
| |--------|------|----------|---------|-------------| | ||
| | `providers` | object | yes | — | Map of provider name to config. Each provider needs at least `mcp_url`. | | ||
| | `providers` | object | yes | — | Map of provider name to OAuth config | | ||
| | `token_dir` | string | no | `/data/oauth-tokens` | Directory for OAuth token files | | ||
|
|
||
| ### Provider Config | ||
|
|
||
| Each provider entry supports two modes: | ||
|
|
||
| **Dynamic mode** (recommended for MCP servers that support RFC 7591): | ||
|
|
||
| | Field | Required | Description | | ||
| |-------|----------|-------------| | ||
| | `mcp_url` | yes | MCP server endpoint — metadata + registration auto-discovered | | ||
|
|
||
| **Static mode** (for providers without dynamic registration): | ||
|
|
||
| | Field | Required | Description | | ||
| |-------|----------|-------------| | ||
| | `mcp_url` | yes | MCP server endpoint | | ||
| | `authorize_endpoint` | yes | OAuth authorize URL | | ||
| | `token_endpoint` | yes | OAuth token exchange URL | | ||
| | `client_id` | yes | OAuth client ID | | ||
| | `client_secret` | no | OAuth client secret | | ||
| | `scopes` | no | Space-separated scopes | | ||
|
|
||
| Mode is auto-detected: if `client_id` is absent, dynamic mode is used. | ||
|
|
||
| ## What It Contributes | ||
|
|
||
| - **Gateway:** Shared volume `oauth-tokens` mounted at `token_dir` | ||
| - **Agent:** Same volume accessible for MCP client token reads | ||
| - **Gateway middleware:** Token injection + 401 with authorize URL when unauthenticated | ||
| - **Gateway route:** `/plugins/mcp-oauth/callback` — OAuth code exchange handler | ||
| - **Gateway volume:** Shared `oauth-tokens` volume at `token_dir` | ||
|
|
||
| ## OAuth Flow | ||
|
|
||
| ``` | ||
| 1. Agent MCP client → request to notion domain | ||
| 2. Gateway middleware: no token file → returns 401 + authorize_url | ||
| 3. User clicks authorize_url → Notion login page | ||
| 4. Notion redirects → https://gateway.example.com/plugins/mcp-oauth/callback?code=X&state=notion | ||
| 5. Gateway callback handler: exchanges code → writes /data/oauth-tokens/notion.json | ||
| 6. Next request → middleware reads token → injects Bearer header → proxied to Notion | ||
| ``` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,208 @@ | ||
| package custom | ||
|
|
||
| import ( | ||
| "context" | ||
| "crypto/hmac" | ||
| "crypto/sha256" | ||
| "encoding/hex" | ||
| "encoding/json" | ||
| "fmt" | ||
| "html" | ||
| "io" | ||
| "log/slog" | ||
| "net" | ||
| "net/http" | ||
| "net/url" | ||
| "os" | ||
| "strings" | ||
| "time" | ||
|
|
||
| "github.com/donbader/agent-sandbox/core/sdk/gateway" | ||
| ) | ||
|
|
||
| type oauthCallbackConfig struct { | ||
| TokenEndpoint string | ||
| ClientID string | ||
| ClientSecret string | ||
| } | ||
|
|
||
| var ( | ||
| oauthCallbackProviders = map[string]oauthCallbackConfig{} | ||
| oauthCallbackTokenDir string | ||
| oauthCallbackHMACKey []byte | ||
| ) | ||
|
|
||
| func init() { | ||
| oauthCallbackTokenDir = "{{ .options.token_dir }}" | ||
| providersJSON := `{{ toJSON .options.providers }}` | ||
|
|
||
| // Derive HMAC key from providers config (same derivation as middleware) | ||
| h := sha256.Sum256([]byte(providersJSON)) | ||
| oauthCallbackHMACKey = h[:] | ||
|
|
||
| var providers map[string]map[string]any | ||
| if err := json.Unmarshal([]byte(providersJSON), &providers); err != nil { | ||
| slog.Error("oauth-callback: failed to parse providers", "error", err) | ||
| } else { | ||
| for name, cfg := range providers { | ||
| p := oauthCallbackConfig{} | ||
| if v, ok := cfg["token_endpoint"].(string); ok { | ||
| p.TokenEndpoint = v | ||
| } | ||
| if v, ok := cfg["client_id"].(string); ok { | ||
| p.ClientID = v | ||
| } | ||
| if v, ok := cfg["client_secret"].(string); ok { | ||
| p.ClientSecret = v | ||
| } | ||
| oauthCallbackProviders[name] = p | ||
| } | ||
| } | ||
| gateway.RegisterRoute(gateway.RouteDef{ | ||
| Path: "{{ .path }}", | ||
| Handler: handleOAuthCallback, | ||
| }) | ||
| } | ||
|
|
||
| func handleOAuthCallback(w http.ResponseWriter, r *http.Request) { | ||
| code := r.URL.Query().Get("code") | ||
| state := r.URL.Query().Get("state") | ||
| if code == "" { | ||
| http.Error(w, "missing code parameter", http.StatusBadRequest) | ||
| return | ||
| } | ||
| if state == "" { | ||
| http.Error(w, "missing state parameter", http.StatusBadRequest) | ||
| return | ||
| } | ||
| // Validate HMAC state: format is "hmac_sig:provider_name" | ||
| parts := strings.SplitN(state, ":", 2) | ||
| if len(parts) != 2 { | ||
| http.Error(w, "invalid state format", http.StatusForbidden) | ||
| return | ||
| } | ||
| sig, providerName := parts[0], parts[1] | ||
| mac := hmac.New(sha256.New, oauthCallbackHMACKey) | ||
| mac.Write([]byte(providerName)) | ||
| expectedSig := hex.EncodeToString(mac.Sum(nil))[:16] | ||
| if !hmac.Equal([]byte(sig), []byte(expectedSig)) { | ||
| http.Error(w, "invalid state signature", http.StatusForbidden) | ||
| return | ||
| } | ||
| provider, ok := oauthCallbackProviders[providerName] | ||
| if !ok { | ||
| http.Error(w, "unknown provider", http.StatusBadRequest) | ||
| return | ||
| } | ||
| if provider.TokenEndpoint == "" { | ||
| http.Error(w, "provider not configured", http.StatusInternalServerError) | ||
| return | ||
| } | ||
| redirectURI := "{{ .public_url }}{{ .path }}" | ||
| token, err := exchangeCodeForToken(provider, code, redirectURI) | ||
| if err != nil { | ||
| slog.Error("oauth-callback: token exchange failed", "error", err) | ||
| http.Error(w, "token exchange failed", http.StatusInternalServerError) | ||
| return | ||
| } | ||
| tokenFile := oauthCallbackTokenDir + "/" + providerName + ".json" | ||
| if err := writeOAuthToken(tokenFile, token, provider); err != nil { | ||
| slog.Error("oauth-callback: failed to save token", "error", err) | ||
| http.Error(w, "failed to save token", http.StatusInternalServerError) | ||
| return | ||
| } | ||
| gateway.RegisterSecret(token.AccessToken) | ||
| w.Header().Set("Content-Type", "text/html; charset=utf-8") | ||
| _, _ = fmt.Fprintf(w, `<!DOCTYPE html><html><body> | ||
| <h1>Authorization successful</h1> | ||
| <p>Provider <strong>%s</strong> connected. You can close this tab.</p> | ||
| </body></html>`, html.EscapeString(providerName)) | ||
| } | ||
|
|
||
| type oauthTokenExchangeResponse struct { | ||
| AccessToken string `json:"access_token"` | ||
| RefreshToken string `json:"refresh_token"` | ||
| ExpiresIn int64 `json:"expires_in"` | ||
| } | ||
|
|
||
| func exchangeCodeForToken(provider oauthCallbackConfig, code, redirectURI string) (*oauthTokenExchangeResponse, error) { | ||
| u, err := url.Parse(provider.TokenEndpoint) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("invalid token_endpoint: %w", err) | ||
| } | ||
| if u.Scheme != "https" { | ||
| return nil, fmt.Errorf("token_endpoint must use https, got %q", u.Scheme) | ||
| } | ||
| params := url.Values{ | ||
| "grant_type": {"authorization_code"}, | ||
| "code": {code}, | ||
| "client_id": {provider.ClientID}, | ||
| "redirect_uri": {redirectURI}, | ||
| } | ||
| if provider.ClientSecret != "" { | ||
| params.Set("client_secret", provider.ClientSecret) | ||
| } | ||
| client := &http.Client{Timeout: 30 * time.Second, Transport: cbSSRFSafe()} | ||
| resp, err := client.Post(provider.TokenEndpoint, "application/x-www-form-urlencoded", strings.NewReader(params.Encode())) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("request failed: %w", err) | ||
| } | ||
| defer func() { _ = resp.Body.Close() }() | ||
| body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("reading response: %w", err) | ||
| } | ||
| if resp.StatusCode != http.StatusOK { | ||
| return nil, fmt.Errorf("token endpoint returned %d", resp.StatusCode) | ||
| } | ||
| var tr oauthTokenExchangeResponse | ||
| if err := json.Unmarshal(body, &tr); err != nil { | ||
| return nil, fmt.Errorf("parsing response: %w", err) | ||
| } | ||
| return &tr, nil | ||
| } | ||
|
|
||
| func writeOAuthToken(path string, token *oauthTokenExchangeResponse, _ oauthCallbackConfig) error { | ||
| expiresIn := token.ExpiresIn | ||
| if expiresIn == 0 { | ||
| expiresIn = 3600 | ||
| } | ||
| stored := map[string]any{ | ||
| "access_token": token.AccessToken, | ||
| "expires_at": time.Now().Unix() + expiresIn, | ||
| } | ||
| if token.RefreshToken != "" { | ||
| stored["refresh_token"] = token.RefreshToken | ||
| } | ||
| data, err := json.MarshalIndent(stored, "", " ") | ||
| if err != nil { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ⚪ [unknown] [security] if provider.ClientSecret != "" {
stored["client_secret"] = provider.ClientSecret
}The OAuth Storing secrets in token files expands the blast radius: a compromised agent container, an accidental log dump of the token file, or volume backup exposure all leak the secret alongside the token. |
||
| return err | ||
| } | ||
| tmp := path + ".tmp" | ||
| if err := os.WriteFile(tmp, data, 0600); err != nil { | ||
| return err | ||
| } | ||
| return os.Rename(tmp, path) | ||
| } | ||
|
|
||
| func cbSSRFSafe() *http.Transport { | ||
| return &http.Transport{ | ||
| DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { | ||
| host, port, err := net.SplitHostPort(addr) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("invalid address %q: %w", addr, err) | ||
| } | ||
| ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("DNS lookup failed for %q: %w", host, err) | ||
| } | ||
| for _, ip := range ips { | ||
| if ip.IP.IsLoopback() || ip.IP.IsPrivate() || ip.IP.IsLinkLocalUnicast() || ip.IP.IsLinkLocalMulticast() { | ||
| return nil, fmt.Errorf("refusing to connect to private IP %s", ip.IP) | ||
| } | ||
| } | ||
| dialer := &net.Dialer{Timeout: 10 * time.Second} | ||
| return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) | ||
| }, | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
⚪ [unknown] [security] CSRF:
stateused as provider name instead of a nonceThe
stateparameter is validated only by checking whether it matches a known provider name. It carries no CSRF protection. RFC 6749 §10.12 requiresstateto be an unguessable random value tied to the session that initiated the authorization request. As-is, any party that obtains a valid authorizationcode(e.g., via referrer leakage, log exposure, or interception) can craft:and the handler will exchange the code and write the resulting token without any verification that this callback corresponds to a flow this gateway started.
Fix: generate a random nonce on the authorize redirect, store it (e.g., in a short-lived file or in-memory map keyed by nonce), and verify on callback. Encode both nonce and provider name in
state(e.g.,notion:hex-nonce).