Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions core/gateway/cmd/gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/donbader/agent-sandbox/core/gateway/internal/mitm"
"github.com/donbader/agent-sandbox/core/gateway/internal/proxy"
"github.com/donbader/agent-sandbox/core/gateway/internal/redact"
"github.com/donbader/agent-sandbox/core/sdk/gateway"

// Custom middleware compilation target — user .go files are copied here at generate-time.
_ "github.com/donbader/agent-sandbox/core/gateway/middlewares/custom"
Expand Down Expand Up @@ -48,8 +49,20 @@ func main() {
os.Exit(1)
}

// Collect secret values from rewriter env vars for value-based redaction.
secrets := collectSecrets(cfg.Rewriters)
// Build rewriters early so we can collect secrets from them (e.g. OAuth tokens)
// before constructing the redacting logger.
rewriters := buildRewriters(cfg.Rewriters)

// Collect secret values for value-based log redaction from two sources:
// 1. Rewriters that implement SecretProvider (auth-header env vars, OAuth tokens).
var secrets []string
for _, rw := range rewriters {
if sp, ok := rw.(mitm.SecretProvider); ok {
secrets = append(secrets, sp.Secrets()...)
}
}
// 2. Secrets declared by custom middleware via gateway.RegisterSecret().
secrets = append(secrets, gateway.Secrets()...)

jsonHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
Expand Down Expand Up @@ -77,9 +90,6 @@ func main() {
// Start TCP proxy
p := proxy.New(cfg)

// Build rewriters from config (shared between MITM and HTTP handlers)
rewriters := buildRewriters(cfg.Rewriters)

// Generate CA and register MITM handler if MITM domains are configured
if len(cfg.MITMDomains) > 0 {
slog.Info("generating CA keypair for MITM")
Expand Down Expand Up @@ -176,19 +186,3 @@ func buildRewriters(cfgs []proxy.RewriterConfig) []mitm.Rewriter {
return rewriters
}

// collectSecrets reads the raw secret values from environment variables referenced
// by rewriter configs. These values are used for value-based log redaction.
func collectSecrets(cfgs []proxy.RewriterConfig) []string {
var secrets []string
for _, rc := range cfgs {
switch rc.Type {
case "auth-header":
if rc.EnvVar != "" {
if v := os.Getenv(rc.EnvVar); v != "" {
secrets = append(secrets, v)
}
}
}
}
return secrets
}
10 changes: 10 additions & 0 deletions core/gateway/internal/mitm/auth_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type AuthHeaderRewriter struct {
domains []string
header string
headerValue string // pre-computed final header value
rawSecret string // original env var value for redaction
}

// NewAuthHeaderRewriter creates a rewriter that injects a header for the given domains.
Expand All @@ -42,9 +43,18 @@ func NewAuthHeaderRewriter(domains []string, header, valueFormat, envVar string)
domains: domains,
header: header,
headerValue: headerValue,
rawSecret: value,
}, nil
}

// Secrets returns the raw secret value for log redaction.
func (r *AuthHeaderRewriter) Secrets() []string {
if r.rawSecret == "" {
return nil
}
return []string{r.rawSecret}
}

// RewriteRequest injects the configured header if the request host matches one of the
// configured domains. Returns true if the header was injected.
func (r *AuthHeaderRewriter) RewriteRequest(req *http.Request) bool {
Expand Down
51 changes: 37 additions & 14 deletions core/gateway/internal/mitm/mitm.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@ import (
"sync"
)

// SecretProvider is implemented by rewriters that hold sensitive values (e.g. tokens)
// that should be redacted from logs.
type SecretProvider interface {
Secrets() []string
}

// Handler implements proxy.RequestHandler for MITM domains.
// It terminates TLS using the sandbox CA, parses HTTP requests,
// applies rewriters, and forwards to the real destination.
type Handler struct {
domains []string
caCert tls.Certificate
certCache *CertCache
rewriters []Rewriter
domains []string
caCert tls.Certificate
certCache *CertCache
rewriters []Rewriter
transportCache sync.Map // keyed by serverName → *http.Transport
}

// Rewriter modifies HTTP requests before forwarding.
Expand Down Expand Up @@ -134,6 +141,31 @@ func (h *Handler) Handle(clientConn net.Conn, initialData []byte, serverName str
}
}

// getTransport returns a cached *http.Transport for the given serverName, creating
// one on first use. Reusing transports enables TCP/TLS connection pooling.
func (h *Handler) getTransport(serverName string) *http.Transport {
insecure := os.Getenv("GATEWAY_INSECURE_UPSTREAM") == "true"

if v, ok := h.transportCache.Load(serverName); ok {
t, _ := v.(*http.Transport)
return t
}

t := &http.Transport{
TLSClientConfig: &tls.Config{
ServerName: serverName,
InsecureSkipVerify: insecure, //nolint:gosec // test-only
},
// Disable compression so we can stream the raw response bytes.
DisableCompression: true,
}

// Store, but prefer an existing entry if a concurrent goroutine beat us.
actual, _ := h.transportCache.LoadOrStore(serverName, t)
result, _ := actual.(*http.Transport)
return result
}

// forwardRequest sends the request to the real server over TLS.
func (h *Handler) forwardRequest(req *http.Request, serverName string) (*http.Response, error) {
// Set the host header and request URI
Expand All @@ -149,17 +181,8 @@ func (h *Handler) forwardRequest(req *http.Request, serverName string) (*http.Re
req.URL.Scheme = "https"
}

transport := &http.Transport{
TLSClientConfig: &tls.Config{
ServerName: serverName,
InsecureSkipVerify: insecure, //nolint:gosec // test-only
},
// Disable compression so we can stream the raw response bytes
DisableCompression: true,
}

client := &http.Client{
Transport: transport,
Transport: h.getTransport(serverName),
// Don't follow redirects — pass them through
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
Expand Down
16 changes: 16 additions & 0 deletions core/gateway/internal/mitm/mitm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,19 @@ func TestHandler_Matches(t *testing.T) {
t.Error("expected no match for unknown.com")
}
}

func TestHandler_TransportReuse(t *testing.T) {
ca := testCA(t)
h := NewHandler([]string{"example.com"}, ca, nil)

t1 := h.getTransport("example.com")
t2 := h.getTransport("example.com")
if t1 != t2 {
t.Error("expected same transport to be reused for same host")
}

t3 := h.getTransport("other.com")
if t1 == t3 {
t.Error("expected different transport for different hosts")
}
}
11 changes: 11 additions & 0 deletions core/gateway/internal/mitm/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ func NewOAuthRewriter(domains []string, tokenFile string) (*OAuthRewriter, error
return r, nil
}

// Secrets implements SecretProvider. Returns the current cached access token
// so it can be added to the log redaction list.
func (r *OAuthRewriter) Secrets() []string {
r.mu.Lock()
defer r.mu.Unlock()
if r.cachedToken != nil && r.cachedToken.AccessToken != "" {
return []string{r.cachedToken.AccessToken}
}
return nil
}

// RewriteRequest injects a Bearer Authorization header if the request host matches
// one of the configured domains. Returns true if the header was injected.
func (r *OAuthRewriter) RewriteRequest(req *http.Request) bool {
Expand Down
24 changes: 24 additions & 0 deletions core/gateway/internal/mitm/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,30 @@ func TestOAuthRewriter_CachesToken(t *testing.T) {
assert.Equal(t, "Bearer cached-token", req2.Header.Get("Authorization"))
}

func TestOAuthRewriter_ImplementsSecretProvider(t *testing.T) {
tokenFile := writeTestToken(t, &StoredToken{
AccessToken: "super-secret-token",
RefreshToken: strPtr("refresh"),
ExpiresAt: time.Now().Unix() + 3600,
TokenEndpoint: "https://example.com/token",
ClientID: "cid",
ClientSecret: nil,
})

rw, err := NewOAuthRewriter([]string{"mcp.notion.com"}, tokenFile)
require.NoError(t, err)

// Prime the cache so the token is available.
req := httptest.NewRequest("POST", "https://mcp.notion.com/mcp", nil)
req.Host = "mcp.notion.com"
rw.RewriteRequest(req)

sp, ok := any(rw).(SecretProvider)
require.True(t, ok, "OAuthRewriter must implement SecretProvider")
secrets := sp.Secrets()
assert.Contains(t, secrets, "super-secret-token")
}

// --- helpers ---

func writeTestToken(t *testing.T, token *StoredToken) string {
Expand Down
8 changes: 6 additions & 2 deletions core/plugins/github-pat/middlewares/github-auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ import (
)

func init() {
// The secret is baked at generate-time from plugin options.
token := "{{ .options.token }}"
if token != "" {
gateway.RegisterSecret(token)
}

gateway.RegisterMiddleware("github-basic-auth", func(ctx *gateway.MiddlewareContext) error {
// Git uses Basic auth with format: x-access-token:<PAT>
// The secret is baked at generate-time from plugin options.
token := "{{ .options.token }}"
if token == "" {
return nil
}
Expand Down
16 changes: 16 additions & 0 deletions core/sdk/gateway/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,19 @@ func Get(name string) (MiddlewareFunc, bool) {
func All() map[string]MiddlewareFunc {
return registry
}

// secrets collects values that should be redacted from logs.
var secrets []string

// RegisterSecret declares a value that should be redacted from gateway logs.
// Call this in init() alongside RegisterMiddleware for any baked-in secrets.
func RegisterSecret(value string) {
if value != "" {
secrets = append(secrets, value)
}
}

// Secrets returns all secrets registered by middleware for log redaction.
func Secrets() []string {
return secrets
}
23 changes: 23 additions & 0 deletions core/sdk/gateway/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package gateway

import "testing"

func TestRegisterSecret(t *testing.T) {
// Reset state
secrets = nil

RegisterSecret("my-secret-token")
RegisterSecret("") // empty should be ignored
RegisterSecret("another-secret")

got := Secrets()
if len(got) != 2 {
t.Fatalf("expected 2 secrets, got %d", len(got))
}
if got[0] != "my-secret-token" {
t.Errorf("expected 'my-secret-token', got %q", got[0])
}
if got[1] != "another-secret" {
t.Errorf("expected 'another-secret', got %q", got[1])
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ import (
)

func init() {
realToken := "{{ .options.bot_token }}"
if realToken != "" {
gateway.RegisterSecret(realToken)
}

gateway.RegisterMiddleware("telegram-token-rewrite", func(ctx *gateway.MiddlewareContext) error {
realToken := "{{ .options.bot_token }}"
if realToken == "" {
return nil
}
Expand Down
6 changes: 3 additions & 3 deletions internal/generate/v1/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ func TestGenerateSchema(t *testing.T) {
assert.Contains(t, schema, "properties")

// Verify key properties exist
props := schema["properties"].(map[string]any)
props, _ := schema["properties"].(map[string]any)
assert.Contains(t, props, "name")
assert.Contains(t, props, "runtime")
assert.Contains(t, props, "gateway")
assert.Contains(t, props, "installations")

// Verify required fields
required := schema["required"].([]any)
required, _ := schema["required"].([]any)
assert.Contains(t, required, "name")
assert.Contains(t, required, "runtime")

// Verify nested runtime properties
runtimeProps := props["runtime"].(map[string]any)["properties"].(map[string]any)
runtimeProps, _ := props["runtime"].(map[string]any)["properties"].(map[string]any)
assert.Contains(t, runtimeProps, "image")
assert.Contains(t, runtimeProps, "extra_builds")
assert.Contains(t, runtimeProps, "entrypoint")
Expand Down
Loading