Skip to content
11 changes: 10 additions & 1 deletion core/gateway/cmd/gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,23 @@ func main() {
}()
}

// Health endpoint
// Health + route handler endpoint
healthAddr := ":8080"
go func() {
mux := http.NewServeMux()
mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
// Serve plugin-registered routes (e.g. /plugins/mcp-oauth/callback)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
handler := gateway.MatchRoute(r.URL.Path)
if handler != nil {
handler(w, r)
return
}
http.NotFound(w, r)
})
if err := http.ListenAndServe(healthAddr, mux); err != nil {
slog.Error("health server error", "error", err)
}
Expand Down
75 changes: 57 additions & 18 deletions core/plugins/mcp-oauth/README.md
Original file line number Diff line number Diff line change
@@ -1,45 +1,84 @@
# mcp-oauth

Provides OAuth token storage for MCP (Model Context Protocol) providers via a shared volume between the gateway and agent.
Provides full OAuth lifecycle for MCP (Model Context Protocol) providers: automatic token injection, refresh, and browser-based authorization via gateway callback.

## How It Works

Declares a named volume (`oauth-tokens`) mounted into both the gateway and agent containers at a configurable path. MCP providers that require OAuth can store and read token files from this shared location.

> **Note:** This plugin currently only handles the volume declaration. You must manually declare gateway services for each MCP provider endpoint in your `agent.yaml`. Full automation (dynamic service entries from provider URLs) is planned.
1. **Middleware** intercepts requests to configured domains. If a valid token exists, injects `Authorization: Bearer <token>`. If no token exists, returns 401 with an `authorize_url` for the user to click.
2. **Callback handler** at `/plugins/mcp-oauth/callback` receives the OAuth authorization code, exchanges it for tokens, and writes the token file to the shared volume.
3. **Shared volume** (`oauth-tokens`) is mounted into both gateway and agent containers so the MCP client can read tokens written by the gateway.

## Usage

```yaml
# agent.yaml
gateway:
public_url: "https://gateway.myagent.example.com"
services:
- url: https://mcp.notion.com

installations:
- plugin: "@builtin/mcp-oauth"
options:
providers:
# Dynamic mode: just provide mcp_url — credentials auto-discovered
notion:
mcp_url: https://mcp.notion.com/mcp
token_dir: "/data/oauth-tokens"

gateway:
services:
- url: https://mcp.notion.com
headers:
Authorization: Bearer ${NOTION_TOKEN}
```

```bash
# .env
NOTION_TOKEN=ntn_xxxx
# Static mode: provide all OAuth details manually
custom-provider:
mcp_url: https://custom.example.com/mcp
authorize_endpoint: https://custom.example.com/oauth/authorize
token_endpoint: https://custom.example.com/oauth/token
client_id: "your-client-id"
client_secret: "your-client-secret"
scopes: "read_content"
token_dir: "/data/oauth-tokens"
```

## Options

| Option | Type | Required | Default | Description |
|--------|------|----------|---------|-------------|
| `providers` | object | yes | — | Map of provider name to config. Each provider needs at least `mcp_url`. |
| `providers` | object | yes | — | Map of provider name to OAuth config |
| `token_dir` | string | no | `/data/oauth-tokens` | Directory for OAuth token files |

### Provider Config

Each provider entry supports two modes:

**Dynamic mode** (recommended for MCP servers that support RFC 7591):

| Field | Required | Description |
|-------|----------|-------------|
| `mcp_url` | yes | MCP server endpoint — metadata + registration auto-discovered |

**Static mode** (for providers without dynamic registration):

| Field | Required | Description |
|-------|----------|-------------|
| `mcp_url` | yes | MCP server endpoint |
| `authorize_endpoint` | yes | OAuth authorize URL |
| `token_endpoint` | yes | OAuth token exchange URL |
| `client_id` | yes | OAuth client ID |
| `client_secret` | no | OAuth client secret |
| `scopes` | no | Space-separated scopes |

Mode is auto-detected: if `client_id` is absent, dynamic mode is used.

## What It Contributes

- **Gateway:** Shared volume `oauth-tokens` mounted at `token_dir`
- **Agent:** Same volume accessible for MCP client token reads
- **Gateway middleware:** Token injection + 401 with authorize URL when unauthenticated
- **Gateway route:** `/plugins/mcp-oauth/callback` — OAuth code exchange handler
- **Gateway volume:** Shared `oauth-tokens` volume at `token_dir`

## OAuth Flow

```
1. Agent MCP client → request to notion domain
2. Gateway middleware: no token file → returns 401 + authorize_url
3. User clicks authorize_url → Notion login page
4. Notion redirects → https://gateway.example.com/plugins/mcp-oauth/callback?code=X&state=notion
5. Gateway callback handler: exchanges code → writes /data/oauth-tokens/notion.json
6. Next request → middleware reads token → injects Bearer header → proxied to Notion
```
208 changes: 208 additions & 0 deletions core/plugins/mcp-oauth/handlers/callback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package custom

import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"html"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/donbader/agent-sandbox/core/sdk/gateway"
)

type oauthCallbackConfig struct {
TokenEndpoint string
ClientID string
ClientSecret string
}

var (
oauthCallbackProviders = map[string]oauthCallbackConfig{}
oauthCallbackTokenDir string
oauthCallbackHMACKey []byte
)

func init() {
oauthCallbackTokenDir = "{{ .options.token_dir }}"
providersJSON := `{{ toJSON .options.providers }}`

// Derive HMAC key from providers config (same derivation as middleware)
h := sha256.Sum256([]byte(providersJSON))
oauthCallbackHMACKey = h[:]

var providers map[string]map[string]any
if err := json.Unmarshal([]byte(providersJSON), &providers); err != nil {
slog.Error("oauth-callback: failed to parse providers", "error", err)
} else {
for name, cfg := range providers {
p := oauthCallbackConfig{}
if v, ok := cfg["token_endpoint"].(string); ok {
p.TokenEndpoint = v
}
if v, ok := cfg["client_id"].(string); ok {
p.ClientID = v
}
if v, ok := cfg["client_secret"].(string); ok {
p.ClientSecret = v
}
oauthCallbackProviders[name] = p
}
}
gateway.RegisterRoute(gateway.RouteDef{
Path: "{{ .path }}",
Handler: handleOAuthCallback,
})
}

func handleOAuthCallback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[unknown] [security] CSRF: state used as provider name instead of a nonce

The state parameter is validated only by checking whether it matches a known provider name. It carries no CSRF protection. RFC 6749 §10.12 requires state to be an unguessable random value tied to the session that initiated the authorization request. As-is, any party that obtains a valid authorization code (e.g., via referrer leakage, log exposure, or interception) can craft:

GET /plugins/mcp-oauth/callback?code=STOLEN&state=notion

and the handler will exchange the code and write the resulting token without any verification that this callback corresponds to a flow this gateway started.

Fix: generate a random nonce on the authorize redirect, store it (e.g., in a short-lived file or in-memory map keyed by nonce), and verify on callback. Encode both nonce and provider name in state (e.g., notion:hex-nonce).

state := r.URL.Query().Get("state")
if code == "" {
http.Error(w, "missing code parameter", http.StatusBadRequest)
return
}
if state == "" {
http.Error(w, "missing state parameter", http.StatusBadRequest)
return
}
// Validate HMAC state: format is "hmac_sig:provider_name"
parts := strings.SplitN(state, ":", 2)
if len(parts) != 2 {
http.Error(w, "invalid state format", http.StatusForbidden)
return
}
sig, providerName := parts[0], parts[1]
mac := hmac.New(sha256.New, oauthCallbackHMACKey)
mac.Write([]byte(providerName))
expectedSig := hex.EncodeToString(mac.Sum(nil))[:16]
if !hmac.Equal([]byte(sig), []byte(expectedSig)) {
http.Error(w, "invalid state signature", http.StatusForbidden)
return
}
provider, ok := oauthCallbackProviders[providerName]
if !ok {
http.Error(w, "unknown provider", http.StatusBadRequest)
return
}
if provider.TokenEndpoint == "" {
http.Error(w, "provider not configured", http.StatusInternalServerError)
return
}
redirectURI := "{{ .public_url }}{{ .path }}"
token, err := exchangeCodeForToken(provider, code, redirectURI)
if err != nil {
slog.Error("oauth-callback: token exchange failed", "error", err)
http.Error(w, "token exchange failed", http.StatusInternalServerError)
return
}
tokenFile := oauthCallbackTokenDir + "/" + providerName + ".json"
if err := writeOAuthToken(tokenFile, token, provider); err != nil {
slog.Error("oauth-callback: failed to save token", "error", err)
http.Error(w, "failed to save token", http.StatusInternalServerError)
return
}
gateway.RegisterSecret(token.AccessToken)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = fmt.Fprintf(w, `<!DOCTYPE html><html><body>
<h1>Authorization successful</h1>
<p>Provider <strong>%s</strong> connected. You can close this tab.</p>
</body></html>`, html.EscapeString(providerName))
}

type oauthTokenExchangeResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
}

func exchangeCodeForToken(provider oauthCallbackConfig, code, redirectURI string) (*oauthTokenExchangeResponse, error) {
u, err := url.Parse(provider.TokenEndpoint)
if err != nil {
return nil, fmt.Errorf("invalid token_endpoint: %w", err)
}
if u.Scheme != "https" {
return nil, fmt.Errorf("token_endpoint must use https, got %q", u.Scheme)
}
params := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"client_id": {provider.ClientID},
"redirect_uri": {redirectURI},
}
if provider.ClientSecret != "" {
params.Set("client_secret", provider.ClientSecret)
}
client := &http.Client{Timeout: 30 * time.Second, Transport: cbSSRFSafe()}
resp, err := client.Post(provider.TokenEndpoint, "application/x-www-form-urlencoded", strings.NewReader(params.Encode()))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token endpoint returned %d", resp.StatusCode)
}
var tr oauthTokenExchangeResponse
if err := json.Unmarshal(body, &tr); err != nil {
return nil, fmt.Errorf("parsing response: %w", err)
}
return &tr, nil
}

func writeOAuthToken(path string, token *oauthTokenExchangeResponse, _ oauthCallbackConfig) error {
expiresIn := token.ExpiresIn
if expiresIn == 0 {
expiresIn = 3600
}
stored := map[string]any{
"access_token": token.AccessToken,
"expires_at": time.Now().Unix() + expiresIn,
}
if token.RefreshToken != "" {
stored["refresh_token"] = token.RefreshToken
}
data, err := json.MarshalIndent(stored, "", " ")
if err != nil {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[unknown] [security] client_secret written to the shared token file

if provider.ClientSecret != "" {
    stored["client_secret"] = provider.ClientSecret
}

The OAuth client_secret is written into the token JSON file on the shared oauth-tokens volume, which is mounted into both the gateway and agent containers. The secret is already available at runtime via the baked-in template; there's no functional reason to persist it in the token file. The middleware reads client_secret from the storedToken struct for refresh, but it could read from the in-memory config instead.

Storing secrets in token files expands the blast radius: a compromised agent container, an accidental log dump of the token file, or volume backup exposure all leak the secret alongside the token.

return err
}
tmp := path + ".tmp"
if err := os.WriteFile(tmp, data, 0600); err != nil {
return err
}
return os.Rename(tmp, path)
}

func cbSSRFSafe() *http.Transport {
return &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("invalid address %q: %w", addr, err)
}
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, fmt.Errorf("DNS lookup failed for %q: %w", host, err)
}
for _, ip := range ips {
if ip.IP.IsLoopback() || ip.IP.IsPrivate() || ip.IP.IsLinkLocalUnicast() || ip.IP.IsLinkLocalMulticast() {
return nil, fmt.Errorf("refusing to connect to private IP %s", ip.IP)
}
}
dialer := &net.Dialer{Timeout: 10 * time.Second}
return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port))
},
}
}
Loading
Loading