diff --git a/.gitignore b/.gitignore index 9906a27..12082f6 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,6 @@ Thumbs.db channel-manager/node_modules/ node_modules/ .worktrees/ +.ssh/ +ssh*key +bin/agent-sandbox diff --git a/cmd/agent-sandbox/main.go b/cmd/agent-sandbox/main.go index 7c2a68e..3d8a1cc 100644 --- a/cmd/agent-sandbox/main.go +++ b/cmd/agent-sandbox/main.go @@ -17,6 +17,7 @@ import ( "github.com/donbader/agent-sandbox/internal/generate" _ "github.com/donbader/agent-sandbox/internal/plugins" // register core feature plugins "github.com/donbader/agent-sandbox/internal/resolve" + crt "github.com/donbader/agent-sandbox/internal/runtime" "github.com/spf13/cobra" ) @@ -190,10 +191,11 @@ func generateAgent(dir, outDir string, cfg *config.AgentConfig, _ *config.Shared ChannelManager: hasChannelManager, SkipEnvExample: skipEnvExample, GatewaySpec: generate.GatewaySpec{ - BuildImage: "golang:1.26.4-alpine", - BinaryPath: "/gateway", - ListenPort: 8443, - DNSPort: 53, + BuildImage: "golang:1.26.4-alpine", + BinaryPath: "/gateway", + ListenPort: 8443, + HTTPListenPort: 8080, + DNSPort: 53, }, ChannelManagerSpec: generate.ChannelManagerSpec{ BuildImage: "node:22-slim", @@ -216,7 +218,7 @@ func writeFleetCompose(outDir string, agents []string) error { b.WriteString("include:\n") for _, name := range agents { - _, _ = fmt.Fprintf(&b, " - path: %s/docker-compose.yml\n", name) + _, _ = fmt.Fprintf(&b, " - %s/docker-compose.yml\n", name) } composePath := filepath.Join(outDir, "docker-compose.yml") @@ -269,22 +271,46 @@ func ensureSchemaComment(yamlPath string, schemaRelPath string) error { func composeCmd(dir *string) *cobra.Command { cmd := &cobra.Command{ Use: "compose", - Short: "Docker compose passthrough (auto-injects -f .build/docker-compose.yml)", + Short: "Container compose passthrough (auto-injects -f .build/docker-compose.yml)", DisableFlagParsing: true, RunE: func(cmd *cobra.Command, args []string) error { - composePath := filepath.Join(*dir, ".build", "docker-compose.yml") + buildDir := filepath.Join(*dir, ".build") + composePath := filepath.Join(buildDir, "docker-compose.yml") if _, err := os.Stat(composePath); os.IsNotExist(err) { return fmt.Errorf("%s not found — run 'agent-sandbox generate' first", composePath) } - composeArgs := []string{"-f", composePath, "--project-name", "agent-sandbox"} + // Load config to get container_runtime override; ignore errors + // (fleet mode or missing config still auto-detects from PATH). + // Priority: agent.yaml > fleet.yaml shared > auto-detect. + var containerRuntime string + if fleet, err := config.LoadFleet(*dir); err == nil { + containerRuntime = fleet.Shared.ContainerRuntime + } + if cfg, err := config.Load(*dir); err == nil && cfg.ContainerRuntime != "" { + containerRuntime = cfg.ContainerRuntime + } + rt, err := crt.DetectWithOverride(containerRuntime) + if err != nil { + return err + } + + // Fleet mode: expand sub-compose files as multiple -f flags + // instead of relying on the `include` directive (not supported by podman-compose) + composeFiles := expandFleetComposeFiles(buildDir, composePath) + + var composeArgs []string + for _, f := range composeFiles { + composeArgs = append(composeArgs, "-f", f) + } + composeArgs = append(composeArgs, "--project-name", "agent-sandbox") // Auto-inject --env-file if .env exists in project dir envPath := filepath.Join(*dir, ".env") if _, err := os.Stat(envPath); err == nil { composeArgs = append(composeArgs, "--env-file", envPath) } composeArgs = append(composeArgs, args...) - c := exec.Command("docker", append([]string{"compose"}, composeArgs...)...) + c := exec.Command(rt.ComposeCmd[0], append(rt.ComposeCmd[1:], composeArgs...)...) c.Stdin = os.Stdin c.Stdout = os.Stdout c.Stderr = os.Stderr @@ -296,6 +322,40 @@ func composeCmd(dir *string) *cobra.Command { return cmd } +// expandFleetComposeFiles checks if the compose file is a fleet umbrella +// (contains only include directives). If so, returns the individual sub-compose +// file paths. Otherwise returns the single compose file path. +func expandFleetComposeFiles(buildDir, composePath string) []string { + data, err := os.ReadFile(composePath) + if err != nil { + return []string{composePath} + } + + content := string(data) + if !strings.Contains(content, "include:") { + return []string{composePath} + } + + // Parse include entries (format: " - path/to/docker-compose.yml") + var files []string + for line := range strings.SplitSeq(content, "\n") { + line = strings.TrimSpace(line) + if after, ok := strings.CutPrefix(line, "- "); ok { + rel := after + rel = strings.TrimSpace(rel) + abs := filepath.Join(buildDir, rel) + if _, err := os.Stat(abs); err == nil { + files = append(files, abs) + } + } + } + + if len(files) == 0 { + return []string{composePath} + } + return files +} + func validateCmd(dir *string) *cobra.Command { return &cobra.Command{ Use: "validate", @@ -365,6 +425,7 @@ func describePlugin(name string, plugin resolve.FeaturePlugin) string { "claude-code": "Anthropic Claude Code runtime configuration", "pi": "Pi coding agent runtime configuration", "mcp-oauth": "OAuth token injection for remote MCP servers", + "ssh": "SSH server for remote development access", } if desc, ok := descriptions[name]; ok { return desc @@ -414,7 +475,7 @@ func initCmd() *cobra.Command { var features []string var envVars []string - for _, ch := range strings.Split(featureChoice, ",") { + for ch := range strings.SplitSeq(featureChoice, ",") { switch strings.TrimSpace(ch) { case "1": features = append(features, "github-pat") @@ -445,13 +506,13 @@ func initCmd() *cobra.Command { if username == "" { username = "@your_username" } - b.WriteString(" - plugin: telegram\n") - b.WriteString(" access_control:\n") - _, _ = fmt.Fprintf(&b, " allowed_users: [\"%s\"]\n", username) - case "custom-runtime": - b.WriteString(" - plugin: custom-runtime\n") - b.WriteString(" commands:\n") - b.WriteString(" - \"apt-get update && apt-get install -y --no-install-recommends ripgrep && rm -rf /var/lib/apt/lists/*\"\n") + b.WriteString(" - plugin: telegram\n") + b.WriteString(" access_control:\n") + _, _ = fmt.Fprintf(&b, " allowed_users: [\"%s\"]\n", username) + case "custom-runtime": + b.WriteString(" - plugin: custom-runtime\n") + b.WriteString(" commands:\n") + b.WriteString(" - \"apt-get update && apt-get install -y --no-install-recommends ripgrep && rm -rf /var/lib/apt/lists/*\"\n") } } } diff --git a/examples/local-coding-ssh/.env.example b/examples/local-coding-ssh/.env.example new file mode 100644 index 0000000..1386a54 --- /dev/null +++ b/examples/local-coding-ssh/.env.example @@ -0,0 +1,4 @@ +# Environment variables for agent-sandbox +# Copy to .env and fill in values + +STX_LLM_GATEWAY_API_KEY= diff --git a/examples/local-coding-ssh/.gitignore b/examples/local-coding-ssh/.gitignore new file mode 100644 index 0000000..2aab196 --- /dev/null +++ b/examples/local-coding-ssh/.gitignore @@ -0,0 +1,2 @@ +ssh_key* +*ssh_host_key* diff --git a/examples/local-coding-ssh/README.md b/examples/local-coding-ssh/README.md new file mode 100644 index 0000000..1bd5157 --- /dev/null +++ b/examples/local-coding-ssh/README.md @@ -0,0 +1,63 @@ +# Local Coding + SSH Example + +Extends the base `local-coding` example with SSH access into the agent container on port 2222. + +## Prerequisites + +Generate an SSH key pair for agent access: + +```bash +ssh-keygen -t ed25519 -f ssh_key -N "" +``` + +This creates `ssh_key` (private) and `ssh_key.pub` (public). The private key stays on your machine; the public key is mounted into the container as an authorized key. + +Both files are gitignored — do not commit real keys. + +## Setup + +```bash +cd examples/local-coding-ssh + +# Generate the SSH key pair (if not already done) +ssh-keygen -t ed25519 -f ssh_key -N "" + +# Generate build artifacts +agent-sandbox generate + +# Create .env from the example +cp .env.example .env +# Edit .env and fill in: +# STX_LLM_GATEWAY_API_KEY=your-api-key + +# Build and run +agent-sandbox compose up --build +``` + +## Connecting via SSH + +```bash +ssh -i ssh_key -p 2222 agent@localhost +``` + +### SSH Config (for Zed and other tools) + +Add to `~/.ssh/config`: + +``` +Host agent-sandbox + HostName localhost + Port 2222 + User agent + IdentityFile /path/to/examples/local-coding-ssh/ssh_key + StrictHostKeyChecking no + UserKnownHostsFile /dev/null +``` + +Then connect with `ssh agent-sandbox` or use the host name in Zed's SSH remote connections. + +## What's Included + +- **external-services** — gateway intercepts HTTP requests to `host.containers.internal:8000` and injects your real API key from `.env`. +- **ssh** — starts an OpenSSH server on port 2222 inside the container, using your generated public key for authentication. +- **custom-runtime** — overlays codex configuration (model catalog, provider settings) into the agent's home directory. diff --git a/examples/local-coding-ssh/agent.yaml b/examples/local-coding-ssh/agent.yaml new file mode 100644 index 0000000..06c7918 --- /dev/null +++ b/examples/local-coding-ssh/agent.yaml @@ -0,0 +1,19 @@ +# yaml-language-server: $schema=.build/schema.json +name: coder +runtime: codex +log_level: debug +features: + - plugin: external-services + services: + - url: http://host.containers.internal:8000/v1 + headers: + Authorization: Bearer ${STX_LLM_GATEWAY_API_KEY} + + - plugin: ssh + port: 2222 + authorized_keys: "./ssh_key.pub" + + - plugin: custom-runtime + home_override: "./home" + runtime_volumes: + - "agent-home:/home/agent" diff --git a/examples/local-coding-ssh/home/.codex/config.toml b/examples/local-coding-ssh/home/.codex/config.toml new file mode 100644 index 0000000..c288aa5 --- /dev/null +++ b/examples/local-coding-ssh/home/.codex/config.toml @@ -0,0 +1,18 @@ +# --- codex-switch:begin --- +model = "claude-opus-4.6" +model_provider = "agent_gateway_codex" +# --- codex-switch:end --- + +model_catalog_json = "/home/agent/.codex/models.json" + +[model_providers.agent_gateway_kiro] +name = "Agent Gateway (Kiro)" +base_url = "http://host.containers.internal:8000/v1" +http_headers = { Authorization = "Bearer dummy" } +wire_api = "responses" + +[model_providers.agent_gateway_codex] +name = "Agent Gateway (Codex)" +base_url = "http://host.containers.internal:8000/v1" +http_headers = { Authorization = "Bearer dummy" } +wire_api = "responses" diff --git a/examples/local-coding-ssh/home/.codex/models.json b/examples/local-coding-ssh/home/.codex/models.json new file mode 100644 index 0000000..f2a6fbd --- /dev/null +++ b/examples/local-coding-ssh/home/.codex/models.json @@ -0,0 +1,212 @@ +{ + "models": [ + { + "slug": "claude-opus-4.6", + "display_name": "Claude Opus 4.6", + "description": "Anthropic Claude Opus 4.6 via STX LLM Gateway", + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 10, + "upgrade": null, + "base_instructions": "You are a helpful coding assistant.", + "model_messages": null, + "supports_reasoning_summaries": false, + "default_reasoning_summary": "auto", + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": { "mode": "bytes", "limit": 10000 }, + "supports_parallel_tool_calls": true, + "supports_image_detail_original": false, + "context_window": 1000000, + "auto_compact_token_limit": null, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text", "image"] + }, + { + "slug": "claude-sonnet-4.6", + "display_name": "Claude Sonnet 4.6", + "description": "Anthropic Claude Sonnet 4.6 via STX LLM Gateway", + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 9, + "upgrade": null, + "base_instructions": "You are a helpful coding assistant.", + "model_messages": null, + "supports_reasoning_summaries": false, + "default_reasoning_summary": "auto", + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": { "mode": "bytes", "limit": 10000 }, + "supports_parallel_tool_calls": true, + "supports_image_detail_original": false, + "context_window": 1000000, + "auto_compact_token_limit": null, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text", "image"] + }, + { + "slug": "claude-haiku-4.5", + "display_name": "Claude Haiku 4.5", + "description": "Anthropic Claude Haiku 4.5 via STX LLM Gateway", + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 8, + "upgrade": null, + "base_instructions": "You are a helpful coding assistant.", + "model_messages": null, + "supports_reasoning_summaries": false, + "default_reasoning_summary": "auto", + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": { "mode": "bytes", "limit": 10000 }, + "supports_parallel_tool_calls": true, + "supports_image_detail_original": false, + "context_window": 200000, + "auto_compact_token_limit": null, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text", "image"] + }, + { + "slug": "deepseek-3.2", + "display_name": "Deepseek v3.2", + "description": "Deepseek v3.2 via STX LLM Gateway", + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 7, + "upgrade": null, + "base_instructions": "You are a helpful coding assistant.", + "model_messages": null, + "supports_reasoning_summaries": false, + "default_reasoning_summary": "auto", + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": { "mode": "bytes", "limit": 10000 }, + "supports_parallel_tool_calls": true, + "supports_image_detail_original": false, + "context_window": 128000, + "auto_compact_token_limit": null, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text"] + }, + { + "slug": "qwen3-coder-next", + "display_name": "Qwen3 Coder Next", + "description": "Qwen3 Coder Next via STX LLM Gateway", + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 6, + "upgrade": null, + "base_instructions": "You are a helpful coding assistant.", + "model_messages": null, + "supports_reasoning_summaries": false, + "default_reasoning_summary": "auto", + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": { "mode": "bytes", "limit": 10000 }, + "supports_parallel_tool_calls": true, + "supports_image_detail_original": false, + "context_window": 128000, + "auto_compact_token_limit": null, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text"] + }, + { + "slug": "glm-5", + "display_name": "GLM 5", + "description": "GLM 5 via STX LLM Gateway", + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 5, + "upgrade": null, + "base_instructions": "You are a helpful coding assistant.", + "model_messages": null, + "supports_reasoning_summaries": false, + "default_reasoning_summary": "auto", + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": { "mode": "bytes", "limit": 10000 }, + "supports_parallel_tool_calls": true, + "supports_image_detail_original": false, + "context_window": 128000, + "auto_compact_token_limit": null, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text"] + }, + { + "slug": "minimax-m2.1", + "display_name": "MiniMax M2.1", + "description": "MiniMax M2.1 via STX LLM Gateway", + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 4, + "upgrade": null, + "base_instructions": "You are a helpful coding assistant.", + "model_messages": null, + "supports_reasoning_summaries": false, + "default_reasoning_summary": "auto", + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": { "mode": "bytes", "limit": 10000 }, + "supports_parallel_tool_calls": true, + "supports_image_detail_original": false, + "context_window": 128000, + "auto_compact_token_limit": null, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text"] + }, + { + "slug": "minimax-m2.5", + "display_name": "MiniMax M2.5", + "description": "MiniMax M2.5 via STX LLM Gateway", + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 3, + "upgrade": null, + "base_instructions": "You are a helpful coding assistant.", + "model_messages": null, + "supports_reasoning_summaries": false, + "default_reasoning_summary": "auto", + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": { "mode": "bytes", "limit": 10000 }, + "supports_parallel_tool_calls": true, + "supports_image_detail_original": false, + "context_window": 128000, + "auto_compact_token_limit": null, + "effective_context_window_percent": 95, + "experimental_supported_tools": [], + "input_modalities": ["text"] + } + ] +} diff --git a/examples/local-coding-ssh/home/.config/zed/settings.json b/examples/local-coding-ssh/home/.config/zed/settings.json new file mode 100644 index 0000000..30167e2 --- /dev/null +++ b/examples/local-coding-ssh/home/.config/zed/settings.json @@ -0,0 +1,9 @@ +{ + "agent_servers": { + "codex-acp": { + "type": "custom", + "command": "codex-acp", + "args": [], + }, + }, +} diff --git a/examples/local-coding-ssh/scripts/ssh-perms.sh b/examples/local-coding-ssh/scripts/ssh-perms.sh new file mode 100755 index 0000000..7b6cfbb --- /dev/null +++ b/examples/local-coding-ssh/scripts/ssh-perms.sh @@ -0,0 +1,4 @@ +#!/bin/bash +set -e +chmod 700 /home/agent/.ssh +chmod 600 /home/agent/.ssh/authorized_keys diff --git a/examples/local-coding-ssh/scripts/ssh-root-setup.sh b/examples/local-coding-ssh/scripts/ssh-root-setup.sh new file mode 100755 index 0000000..5420fc0 --- /dev/null +++ b/examples/local-coding-ssh/scripts/ssh-root-setup.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -e +cp /run/ssh/host_key /etc/ssh/ssh_host_ed25519_key +chmod 600 /etc/ssh/ssh_host_ed25519_key +ssh-keygen -y -f /etc/ssh/ssh_host_ed25519_key > /etc/ssh/ssh_host_ed25519_key.pub +mkdir -p /home/agent/.ssh +cp /run/ssh/authorized_keys /home/agent/.ssh/authorized_keys +chown -R agent:agent /home/agent/.ssh +/usr/sbin/sshd -p 2222 diff --git a/examples/local-coding/agent.yaml b/examples/local-coding/agent.yaml index 61727c4..6464d3e 100644 --- a/examples/local-coding/agent.yaml +++ b/examples/local-coding/agent.yaml @@ -11,3 +11,5 @@ features: - plugin: custom-runtime home_override: "./home" + runtime_volumes: + - "agent-home:/home/agent" diff --git a/examples/local-coding/home/.codex/config.toml b/examples/local-coding/home/.codex/config.toml index d89e265..36abc49 100644 --- a/examples/local-coding/home/.codex/config.toml +++ b/examples/local-coding/home/.codex/config.toml @@ -1,6 +1,6 @@ # --- codex-switch:begin --- -model = "claude-sonnet-4.6" -model_provider = "agent_gateway_kiro" +model = "claude-opus-4.6" +model_provider = "agent_gateway_codex" # --- codex-switch:end --- model_catalog_json = "/home/agent/.codex/models.json" diff --git a/gateway/internal/mitm/auth_header.go b/gateway/internal/mitm/auth_header.go index a8c5f67..79a62bd 100644 --- a/gateway/internal/mitm/auth_header.go +++ b/gateway/internal/mitm/auth_header.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "os" + "slices" "strings" ) @@ -46,21 +47,18 @@ func NewAuthHeaderRewriter(domains []string, header, valueFormat, envVar string) } // RewriteRequest injects the configured header if the request host matches one of the -// configured domains. Returns true if the header was injected. +// configured domains. Supports both bare hostnames ("api.github.com") and host:port +// entries ("host.internal:8000") for port-aware matching. Returns true if the header was injected. func (r *AuthHeaderRewriter) RewriteRequest(req *http.Request) bool { host := req.Host - // Strip port if present (e.g., "api.github.com:443" → "api.github.com") + // Extract bare hostname for fallback matching + bareHost := host if h, _, err := net.SplitHostPort(host); err == nil { - host = h + bareHost = h } - matched := false - for _, d := range r.domains { - if host == d { - matched = true - break - } - } + // Match against host:port first, then bare hostname + matched := slices.Contains(r.domains, host) || slices.Contains(r.domains, bareHost) if !matched { return false } diff --git a/gateway/internal/mitm/mitm.go b/gateway/internal/mitm/mitm.go index 89d46d4..9f1a6c3 100644 --- a/gateway/internal/mitm/mitm.go +++ b/gateway/internal/mitm/mitm.go @@ -10,6 +10,7 @@ import ( "log/slog" "net" "net/http" + "slices" "strings" "sync" ) @@ -42,12 +43,7 @@ func NewHandler(domains []string, caCert tls.Certificate, rewriters []Rewriter) // Matches returns true if the host is in the MITM domain list. func (h *Handler) Matches(host string) bool { - for _, d := range h.domains { - if host == d { - return true - } - } - return false + return slices.Contains(h.domains, host) } // Handle terminates TLS, parses HTTP, applies rewriters, and forwards. diff --git a/gateway/internal/mitm/oauth.go b/gateway/internal/mitm/oauth.go index 6bbe56c..bce2cba 100644 --- a/gateway/internal/mitm/oauth.go +++ b/gateway/internal/mitm/oauth.go @@ -14,6 +14,7 @@ import ( "net/http" "net/url" "os" + "slices" "strings" "sync" "time" @@ -21,12 +22,12 @@ import ( // StoredToken represents a persisted OAuth token (written by setup, read/updated by this rewriter). type StoredToken struct { - AccessToken string `json:"access_token"` - RefreshToken *string `json:"refresh_token"` - ExpiresAt int64 `json:"expires_at"` - TokenEndpoint string `json:"token_endpoint"` - ClientID string `json:"client_id"` - ClientSecret *string `json:"client_secret"` + AccessToken string `json:"access_token"` + RefreshToken *string `json:"refresh_token"` + ExpiresAt int64 `json:"expires_at"` + TokenEndpoint string `json:"token_endpoint"` + ClientID string `json:"client_id"` + ClientSecret *string `json:"client_secret"` } // OAuthRewriter injects a Bearer token into requests destined for specific domains. @@ -36,10 +37,10 @@ type OAuthRewriter struct { domains []string tokenFile string - mu sync.Mutex - cachedToken *StoredToken - cachedUntil time.Time - httpClient *http.Client + mu sync.Mutex + cachedToken *StoredToken + cachedUntil time.Time + httpClient *http.Client } // NewOAuthRewriter creates a rewriter that reads an OAuth token file and injects @@ -68,20 +69,16 @@ func NewOAuthRewriter(domains []string, tokenFile string) (*OAuthRewriter, error } // RewriteRequest injects a Bearer Authorization header if the request host matches -// one of the configured domains. Returns true if the header was injected. +// one of the configured domains. Supports both bare hostnames and host:port entries +// for port-aware matching. Returns true if the header was injected. func (r *OAuthRewriter) RewriteRequest(req *http.Request) bool { host := req.Host + bareHost := host if h, _, err := net.SplitHostPort(host); err == nil { - host = h + bareHost = h } - matched := false - for _, d := range r.domains { - if host == d { - matched = true - break - } - } + matched := slices.Contains(r.domains, host) || slices.Contains(r.domains, bareHost) if !matched { return false } @@ -134,10 +131,7 @@ func (r *OAuthRewriter) getValidToken() (string, error) { } // Cache until 5 minutes before expiry (minimum 60 seconds). - ttl := stored.ExpiresAt - now - 300 - if ttl < 60 { - ttl = 60 - } + ttl := max(stored.ExpiresAt-now-300, 60) r.cachedToken = stored r.cachedUntil = time.Now().Add(time.Duration(ttl) * time.Second) diff --git a/gateway/internal/mitm/oauth_test.go b/gateway/internal/mitm/oauth_test.go index f4acb70..944a079 100644 --- a/gateway/internal/mitm/oauth_test.go +++ b/gateway/internal/mitm/oauth_test.go @@ -17,7 +17,7 @@ import ( func TestOAuthRewriter_InjectsBearer(t *testing.T) { tokenFile := writeTestToken(t, &StoredToken{ AccessToken: "test-access-token", - RefreshToken: strPtr("test-refresh"), + RefreshToken: new("test-refresh"), ExpiresAt: time.Now().Unix() + 3600, TokenEndpoint: "https://example.com/token", ClientID: "client-id", @@ -38,7 +38,7 @@ func TestOAuthRewriter_InjectsBearer(t *testing.T) { func TestOAuthRewriter_SkipsNonMatchingDomain(t *testing.T) { tokenFile := writeTestToken(t, &StoredToken{ AccessToken: "token", - RefreshToken: strPtr("refresh"), + RefreshToken: new("refresh"), ExpiresAt: time.Now().Unix() + 3600, TokenEndpoint: "https://example.com/token", ClientID: "cid", @@ -77,7 +77,7 @@ func TestOAuthRewriter_RefreshesExpiredToken(t *testing.T) { tokenFile := writeTestToken(t, &StoredToken{ AccessToken: "expired-token", - RefreshToken: strPtr("old-refresh"), + RefreshToken: new("old-refresh"), ExpiresAt: time.Now().Unix() - 100, // Already expired. TokenEndpoint: server.URL, // https://127.0.0.1:PORT ClientID: "client-id", @@ -108,7 +108,7 @@ func TestOAuthRewriter_RefreshesExpiredToken(t *testing.T) { func TestOAuthRewriter_RejectsHTTPTokenEndpoint(t *testing.T) { tokenFile := writeTestToken(t, &StoredToken{ AccessToken: "expired-token", - RefreshToken: strPtr("refresh"), + RefreshToken: new("refresh"), ExpiresAt: time.Now().Unix() - 100, // Expired — triggers refresh. TokenEndpoint: "http://evil.example.com/token", ClientID: "cid", @@ -130,7 +130,7 @@ func TestOAuthRewriter_RejectsHTTPTokenEndpoint(t *testing.T) { func TestOAuthRewriter_BlocksPrivateIPEndpoint(t *testing.T) { tokenFile := writeTestToken(t, &StoredToken{ AccessToken: "expired-token", - RefreshToken: strPtr("refresh"), + RefreshToken: new("refresh"), ExpiresAt: time.Now().Unix() - 100, // Expired — triggers refresh. TokenEndpoint: "https://127.0.0.1:9999/token", ClientID: "cid", @@ -206,7 +206,7 @@ func TestOAuthRewriter_ErrorsWithoutTokenFile(t *testing.T) { func TestOAuthRewriter_HandlesHostWithPort(t *testing.T) { tokenFile := writeTestToken(t, &StoredToken{ AccessToken: "port-token", - RefreshToken: strPtr("refresh"), + RefreshToken: new("refresh"), ExpiresAt: time.Now().Unix() + 3600, TokenEndpoint: "https://example.com/token", ClientID: "cid", @@ -227,7 +227,7 @@ func TestOAuthRewriter_HandlesHostWithPort(t *testing.T) { func TestOAuthRewriter_CachesToken(t *testing.T) { tokenFile := writeTestToken(t, &StoredToken{ AccessToken: "cached-token", - RefreshToken: strPtr("refresh"), + RefreshToken: new("refresh"), ExpiresAt: time.Now().Unix() + 3600, TokenEndpoint: "https://example.com/token", ClientID: "cid", @@ -264,6 +264,7 @@ func writeTestToken(t *testing.T, token *StoredToken) string { return path } +//go:fix inline func strPtr(s string) *string { - return &s + return new(s) } diff --git a/gateway/internal/proxy/config.go b/gateway/internal/proxy/config.go index f1db1c6..ce64f9e 100644 --- a/gateway/internal/proxy/config.go +++ b/gateway/internal/proxy/config.go @@ -23,6 +23,7 @@ type RewriterConfig struct { // Config holds gateway configuration. type Config struct { Listen string `yaml:"listen"` // TCP listen address (e.g., ":8443") + HTTPListen string `yaml:"http_listen"` // HTTP proxy listen address (e.g., ":8080") DNSListen string `yaml:"dns_listen"` // DNS listen address (e.g., ":53") MITMDomains []string `yaml:"mitm_domains"` // domains to MITM (terminate TLS) HTTPServices []HTTPService `yaml:"http_services"` // plain HTTP services to proxy @@ -67,6 +68,9 @@ func LoadConfig(path string) (*Config, error) { if cfg.Listen == "" { cfg.Listen = ":8443" } + if cfg.HTTPListen == "" { + cfg.HTTPListen = ":8080" + } if cfg.DNSListen == "" { cfg.DNSListen = ":53" } diff --git a/gateway/internal/proxy/http_proxy.go b/gateway/internal/proxy/http_proxy.go new file mode 100644 index 0000000..309dc3d --- /dev/null +++ b/gateway/internal/proxy/http_proxy.go @@ -0,0 +1,116 @@ +// Package proxy implements transparent proxies for the gateway. +// http_proxy.go provides a transparent HTTP reverse proxy that intercepts +// plain HTTP requests redirected via iptables, applies rewriters (auth-header +// injection), and forwards upstream. +package proxy + +import ( + "fmt" + "log/slog" + "net" + "net/http" + "net/http/httputil" + "time" + + "github.com/donbader/agent-sandbox/gateway/internal/mitm" +) + +// HTTPProxy is a transparent HTTP reverse proxy that intercepts plain HTTP +// traffic redirected via iptables, applies rewriters, and forwards upstream. +type HTTPProxy struct { + listenAddr string + domains []string + rewriters []mitm.Rewriter + transport *http.Transport +} + +// NewHTTPProxy creates a new HTTP proxy that intercepts requests for the given +// domains and applies rewriters before forwarding. +func NewHTTPProxy(listenAddr string, domains []string, rewriters []mitm.Rewriter) *HTTPProxy { + return &HTTPProxy{ + listenAddr: listenAddr, + domains: domains, + rewriters: rewriters, + transport: &http.Transport{ + DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + DisableCompression: true, + MaxIdleConnsPerHost: 10, + }, + } +} + +// ListenAndServe starts the HTTP proxy listener. +func (h *HTTPProxy) ListenAndServe() error { + server := &http.Server{ + Addr: h.listenAddr, + Handler: h, + ReadTimeout: 30 * time.Second, + WriteTimeout: 60 * time.Second, + } + return server.ListenAndServe() +} + +// ServeHTTP handles each proxied HTTP request. +func (h *HTTPProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { + host := req.Host + + if !h.matchesDomain(host) { + slog.Debug("http proxy: domain not matched, passing through", "host", host) + } + + // Apply rewriters + rewritten := false + for _, rw := range h.rewriters { + if rw.RewriteRequest(req) { + rewritten = true + } + } + slog.Debug("http proxy request", "host", host, "method", req.Method, "path", req.URL.Path, "rewritten", rewritten) + + // Determine upstream target — use the original Host header (includes port) + target := host + if _, _, err := net.SplitHostPort(target); err != nil { + // No port specified, default to 80 + target = net.JoinHostPort(target, "80") + } + + // Forward via reverse proxy + proxy := &httputil.ReverseProxy{ + Director: func(outReq *http.Request) { + outReq.URL.Scheme = "http" + outReq.URL.Host = target + outReq.Host = req.Host + }, + Transport: h.transport, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + slog.Error("http proxy upstream error", "host", host, "error", err) + http.Error(w, fmt.Sprintf("gateway: upstream error: %v", err), http.StatusBadGateway) + }, + } + + proxy.ServeHTTP(w, req) +} + +// matchesDomain checks if the request host is in the configured HTTP domain list. +// Both sides are normalized to host:port (defaulting to port 80) for comparison. +func (h *HTTPProxy) matchesDomain(requestHost string) bool { + normalized := normalizeHostPort(requestHost) + for _, d := range h.domains { + if normalizeHostPort(d) == normalized { + return true + } + } + return false +} + +// normalizeHostPort ensures s is in host:port form, defaulting to port 80. +func normalizeHostPort(s string) string { + host, port, err := net.SplitHostPort(s) + if err != nil { + // No port — default to 80 + return net.JoinHostPort(s, "80") + } + return net.JoinHostPort(host, port) +} diff --git a/gateway/internal/proxy/http_proxy_test.go b/gateway/internal/proxy/http_proxy_test.go new file mode 100644 index 0000000..5c08f50 --- /dev/null +++ b/gateway/internal/proxy/http_proxy_test.go @@ -0,0 +1,104 @@ +package proxy_test + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/donbader/agent-sandbox/gateway/internal/mitm" + "github.com/donbader/agent-sandbox/gateway/internal/proxy" +) + +func TestHTTPProxy_ForwardsRequestAndAppliesRewriter(t *testing.T) { + // Upstream server that echoes back the injected header + var receivedHeader string + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeader = r.Header.Get("X-Test-Token") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer upstream.Close() + + // Extract host:port from upstream URL + upstreamHost := upstream.Listener.Addr().String() + + rewriter := &alwaysRewriter{ + header: "X-Test-Token", + value: "secret-123", + } + + hp := proxy.NewHTTPProxy(":0", []string{upstreamHost}, []mitm.Rewriter{rewriter}) + + // Use httptest to test the handler directly + req := httptest.NewRequest(http.MethodGet, "http://"+upstreamHost+"/test-path", nil) + req.Host = upstreamHost + + rec := httptest.NewRecorder() + hp.ServeHTTP(rec, req) + + // The rewriter should have injected the header before forwarding + if receivedHeader != "secret-123" { + t.Errorf("expected upstream to receive header 'secret-123', got %q", receivedHeader) + } + + resp := rec.Result() + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if string(body) != "ok" { + t.Errorf("expected body 'ok', got %q", string(body)) + } +} + +func TestHTTPProxy_PreservesHostHeader(t *testing.T) { + var receivedHost string + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHost = r.Host + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + upstreamHost := upstream.Listener.Addr().String() + hp := proxy.NewHTTPProxy(":0", []string{upstreamHost}, nil) + + req := httptest.NewRequest(http.MethodGet, "http://"+upstreamHost+"/path", nil) + req.Host = upstreamHost + + rec := httptest.NewRecorder() + hp.ServeHTTP(rec, req) + + if receivedHost != upstreamHost { + t.Errorf("expected Host header %q, got %q", upstreamHost, receivedHost) + } +} + +func TestHTTPProxy_UpstreamError(t *testing.T) { + // Point at an address that refuses connections + hp := proxy.NewHTTPProxy(":0", []string{"127.0.0.1:1"}, nil) + + req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1:1/path", nil) + req.Host = "127.0.0.1:1" + + rec := httptest.NewRecorder() + hp.ServeHTTP(rec, req) + + resp := rec.Result() + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadGateway { + t.Errorf("expected 502, got %d", resp.StatusCode) + } +} + +// alwaysRewriter injects a header on every request regardless of domain. +type alwaysRewriter struct { + header string + value string +} + +func (r *alwaysRewriter) RewriteRequest(req *http.Request) bool { + req.Header.Set(r.header, r.value) + return true +} diff --git a/gateway/internal/proxy/sni.go b/gateway/internal/proxy/sni.go index f34eb0d..605e474 100644 --- a/gateway/internal/proxy/sni.go +++ b/gateway/internal/proxy/sni.go @@ -48,10 +48,7 @@ func extractSNI(data []byte) string { extensionsLen := int(data[pos])<<8 | int(data[pos+1]) pos += 2 - end := pos + extensionsLen - if end > len(data) { - end = len(data) - } + end := min(pos+extensionsLen, len(data)) for pos+4 <= end { extType := int(data[pos])<<8 | int(data[pos+1]) diff --git a/internal/config/config.go b/internal/config/config.go index 7933598..9d73f5b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -53,12 +53,13 @@ func (f *FeatureEntry) UnmarshalYAML(value *yaml.Node) error { // AgentConfig represents an agent.yaml file. type AgentConfig struct { - Name string `yaml:"name" schema:"Agent name" required:"true" examples:"my-agent"` - Runtime string `yaml:"runtime" schema:"Runtime plugin name" required:"true" enum:"codex,claude-code,pi"` - LogLevel string `yaml:"log_level" schema:"Log verbosity level" default:"info" enum:"info,debug"` - Gateway *bool `yaml:"gateway" schema:"Enable transparent gateway proxy" default:"true"` - Workdir string `yaml:"workdir" schema:"Working directory for the agent. Supports {{ .AGENT_HOME }} template variable." examples:"{{ .AGENT_HOME }}/workspace"` - Features []FeatureEntry `yaml:"features" schema:"Feature plugins and their configuration"` + Name string `yaml:"name" schema:"Agent name" required:"true" examples:"my-agent"` + Runtime string `yaml:"runtime" schema:"Runtime plugin name" required:"true" enum:"codex"` + ContainerRuntime string `yaml:"container_runtime" schema:"Container runtime override" enum:"docker,podman"` + LogLevel string `yaml:"log_level" schema:"Log verbosity level" default:"info" enum:"info,debug"` + Gateway *bool `yaml:"gateway" schema:"Enable transparent gateway proxy" default:"true"` + Workdir string `yaml:"workdir" schema:"Working directory inside the container" examples:"/workspace"` + Features []FeatureEntry `yaml:"features" schema:"Feature plugins and their configuration"` } // GatewayEnabled returns whether the gateway should be included. @@ -101,7 +102,8 @@ type FleetConfig struct { // SharedBlock holds features shared across all agents. type SharedBlock struct { - Features []FeatureEntry `yaml:"features"` + ContainerRuntime string `yaml:"container_runtime" schema:"Container runtime override" enum:"docker,podman"` + Features []FeatureEntry `yaml:"features"` } // LoadFleet reads and parses a fleet.yaml file from the given directory. diff --git a/internal/generate/buildspec.go b/internal/generate/buildspec.go index d8b29e3..620aeff 100644 --- a/internal/generate/buildspec.go +++ b/internal/generate/buildspec.go @@ -3,10 +3,11 @@ package generate // GatewaySpec defines how the gateway container is built and configured. // Injected into Generator by the CLI — generator doesn't own these details. type GatewaySpec struct { - BuildImage string // Docker image for compilation (e.g. "golang:1.26.4-alpine") - BinaryPath string // output binary path (e.g. "/gateway") - ListenPort int // TLS interception port (e.g. 8443) - DNSPort int // DNS resolver port (e.g. 5353) + BuildImage string // Docker image for compilation (e.g. "golang:1.26.4-alpine") + BinaryPath string // output binary path (e.g. "/gateway") + ListenPort int // TLS interception port (e.g. 8443) + HTTPListenPort int // HTTP proxy port (e.g. 8080) + DNSPort int // DNS resolver port (e.g. 5353) } // ChannelManagerSpec defines how the channel manager is built and configured. diff --git a/internal/generate/channel_manager.go b/internal/generate/channel_manager.go index 8143fdd..c1ee2d7 100644 --- a/internal/generate/channel_manager.go +++ b/internal/generate/channel_manager.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io/fs" + "maps" "os" "path/filepath" "strings" @@ -255,9 +256,7 @@ func (g *Generator) writeChannelConfig() error { // Pass plugin-specific config to channel-manager (generic — no plugin knowledge here) for _, f := range g.Features { - for k, v := range f.ChannelConfig { - config[k] = v - } + maps.Copy(config, f.ChannelConfig) } data, err := json.MarshalIndent(config, "", " ") diff --git a/internal/generate/compose.go b/internal/generate/compose.go index b31dfd0..dbb1a19 100644 --- a/internal/generate/compose.go +++ b/internal/generate/compose.go @@ -12,7 +12,8 @@ type ComposeBuilder struct { AgentName string GatewayName string LogLevel string - Ports []string + Ports []string // runtime ports (on gateway in gateway mode, on agent in single mode) + AgentPorts []string // feature-contributed ports (always on agent) Volumes []string NamedVolumes []string EnvVars []string @@ -20,20 +21,28 @@ type ComposeBuilder struct { HasMITM bool GatewayCertDir string ExternalNetworks []string + Capabilities []string // additional capabilities from features } // buildComposeBuilder constructs a ComposeBuilder from the Generator state. func (g *Generator) buildComposeBuilder() *ComposeBuilder { + // Merge runtime ports with feature-contributed ports + runtimePorts := append([]string{}, g.Runtime.Ports...) + featurePorts := g.collectFeaturePorts() + cb := &ComposeBuilder{ - AgentName: g.Config.Name, - LogLevel: g.logLevel(), - Ports: g.Runtime.Ports, - EnvVars: g.mergedEnvVars(), + AgentName: g.Config.Name, + LogLevel: g.logLevel(), + Ports: append(runtimePorts, featurePorts...), + AgentPorts: featurePorts, + EnvVars: g.mergedEnvVars(), + Capabilities: g.collectCapabilities(), } if g.Gateway { cb.Variant = "gateway" cb.GatewayName = g.Config.Name + "-gateway" + cb.Ports = runtimePorts // gateway only gets runtime ports cb.AgentEnv = g.collectAgentEnv() cb.HasMITM = g.hasMITMDomains() cb.GatewayCertDir = gatewayCertDir diff --git a/internal/generate/dockerfile.go b/internal/generate/dockerfile.go index 9160782..a89d446 100644 --- a/internal/generate/dockerfile.go +++ b/internal/generate/dockerfile.go @@ -42,7 +42,7 @@ func NewDockerfileBuilder(g *Generator, variant string) *DockerfileBuilder { Install: g.Runtime.Install, HasEntrypoint: g.needsEntrypoint(), HasHomeOverride: g.hasHomeOverride(), - HasHooks: g.hasHooks(), + HasHooks: g.hasHooks() || g.hasRootHooks(), Cmd: g.Runtime.Cmd, VolumePaths: g.collectVolumePaths(), Workdir: g.Workdir, diff --git a/internal/generate/entrypoint.go b/internal/generate/entrypoint.go index 032db21..2234ace 100644 --- a/internal/generate/entrypoint.go +++ b/internal/generate/entrypoint.go @@ -21,6 +21,8 @@ type EntrypointBuilder struct { HasHomeOverride bool HasHooks bool Hooks []string // base filenames only + HasRootHooks bool + RootHooks []string // base filenames only User string Ports []PortMapping RuntimeCmd string @@ -100,6 +102,14 @@ func (g *Generator) writeAgentEntrypoint() error { } } + // Collect root hooks (base filenames only) + var rootHooks []string + for _, f := range g.Features { + for _, hook := range f.RootHooks { + rootHooks = append(rootHooks, filepath.Base(hook)) + } + } + // Collect port mappings var ports []PortMapping if g.Gateway && len(g.Runtime.Ports) > 0 { @@ -116,6 +126,8 @@ func (g *Generator) writeAgentEntrypoint() error { HasHomeOverride: g.hasHomeOverride(), HasHooks: len(hooks) > 0, Hooks: hooks, + HasRootHooks: len(rootHooks) > 0, + RootHooks: rootHooks, User: g.Runtime.User, Ports: ports, RuntimeCmd: strings.Join(g.Runtime.Cmd, " "), diff --git a/internal/generate/gateway_config.go b/internal/generate/gateway_config.go index 3c905d1..b6d4fa9 100644 --- a/internal/generate/gateway_config.go +++ b/internal/generate/gateway_config.go @@ -2,8 +2,10 @@ package generate import ( "fmt" + "net/url" "os" "path/filepath" + "strings" "github.com/donbader/agent-sandbox/internal/resolve" ) @@ -17,22 +19,26 @@ type PortForward struct { // GatewayConfigBuilder holds data for rendering the gateway-config.yaml template. type GatewayConfigBuilder struct { - ListenPort int - DNSPort int - MITMDomains []string - HTTPServices []resolve.HTTPService - Rewriters []resolve.RewriterConfig - PortForwards []PortForward + ListenPort int + HTTPListenPort int + DNSPort int + MITMDomains []string + HTTPServices []resolve.HTTPService + Rewriters []resolve.RewriterConfig + PortForwards []PortForward } // buildGatewayConfigBuilder constructs a GatewayConfigBuilder from the Generator state. func (g *Generator) buildGatewayConfigBuilder() *GatewayConfigBuilder { + mitmDomains, _ := g.splitDomainsByScheme() + gcb := &GatewayConfigBuilder{ - ListenPort: g.GatewaySpec.ListenPort, - DNSPort: g.GatewaySpec.DNSPort, - MITMDomains: g.collectMITMDomains(), - HTTPServices: g.collectHTTPServices(), - Rewriters: g.collectRewriters(), + ListenPort: g.GatewaySpec.ListenPort, + HTTPListenPort: g.GatewaySpec.HTTPListenPort, + DNSPort: g.GatewaySpec.DNSPort, + MITMDomains: mitmDomains, + HTTPServices: g.collectHTTPServices(), + Rewriters: g.collectRewriters(), } for _, p := range g.Runtime.Ports { @@ -47,6 +53,35 @@ func (g *Generator) buildGatewayConfigBuilder() *GatewayConfigBuilder { return gcb } +// splitDomainsByScheme separates MITM domain entries into TLS (no scheme or https://) +// and HTTP (http://) groups. HTTP entries are stripped of their scheme for the config. +func (g *Generator) splitDomainsByScheme() (mitmDomains, httpDomains []string) { + for _, d := range g.collectMITMDomains() { + if strings.HasPrefix(d, "http://") { + // Strip scheme, keep host (and port if present) + parsed, err := url.Parse(d) + if err != nil { + // Malformed — treat as MITM domain + mitmDomains = append(mitmDomains, d) + continue + } + httpDomains = append(httpDomains, parsed.Host) + } else if strings.HasPrefix(d, "https://") { + // Strip scheme for MITM list + parsed, err := url.Parse(d) + if err != nil { + mitmDomains = append(mitmDomains, d) + continue + } + mitmDomains = append(mitmDomains, parsed.Host) + } else { + // No scheme — default to MITM (TLS) + mitmDomains = append(mitmDomains, d) + } + } + return mitmDomains, httpDomains +} + // writeGatewayConfig generates .build/gateway-config.yaml using a template. func (g *Generator) writeGatewayConfig() error { gcb := g.buildGatewayConfigBuilder() diff --git a/internal/generate/gateway_config_test.go b/internal/generate/gateway_config_test.go index e2c1435..042f2eb 100644 --- a/internal/generate/gateway_config_test.go +++ b/internal/generate/gateway_config_test.go @@ -180,12 +180,13 @@ func TestBuildGatewayConfigBuilder(t *testing.T) { }, }, }, - GatewaySpec: GatewaySpec{ListenPort: 8443, DNSPort: 5353}, + GatewaySpec: GatewaySpec{ListenPort: 8443, HTTPListenPort: 8080, DNSPort: 5353}, } gcb := g.buildGatewayConfigBuilder() assert.Equal(t, 8443, gcb.ListenPort) + assert.Equal(t, 8080, gcb.HTTPListenPort) assert.Equal(t, 5353, gcb.DNSPort) assert.Equal(t, []string{"api.telegram.org"}, gcb.MITMDomains) assert.Len(t, gcb.Rewriters, 1) @@ -196,3 +197,154 @@ func TestBuildGatewayConfigBuilder(t *testing.T) { assert.Equal(t, "coder", gcb.PortForwards[0].AgentName) }) } + +func TestSplitDomainsByScheme(t *testing.T) { + t.Run("no scheme defaults to MITM", func(t *testing.T) { + g := &Generator{ + Features: []*resolve.FeatureContributions{ + {MITMDomains: []string{"api.github.com", "api.telegram.org"}}, + }, + } + + mitm, httpDomains := g.splitDomainsByScheme() + + assert.Equal(t, []string{"api.github.com", "api.telegram.org"}, mitm) + assert.Empty(t, httpDomains) + }) + + t.Run("http:// scheme goes to HTTP domains", func(t *testing.T) { + g := &Generator{ + Features: []*resolve.FeatureContributions{ + {MITMDomains: []string{"http://host.containers.internal:8000"}}, + }, + } + + mitm, httpDomains := g.splitDomainsByScheme() + + assert.Empty(t, mitm) + assert.Equal(t, []string{"host.containers.internal:8000"}, httpDomains) + }) + + t.Run("https:// scheme goes to MITM domains (stripped)", func(t *testing.T) { + g := &Generator{ + Features: []*resolve.FeatureContributions{ + {MITMDomains: []string{"https://api.github.com"}}, + }, + } + + mitm, httpDomains := g.splitDomainsByScheme() + + assert.Equal(t, []string{"api.github.com"}, mitm) + assert.Empty(t, httpDomains) + }) + + t.Run("mixed schemes are separated correctly", func(t *testing.T) { + g := &Generator{ + Features: []*resolve.FeatureContributions{ + {MITMDomains: []string{ + "api.github.com", + "http://host.containers.internal:8000", + "https://api.telegram.org", + "http://host.containers.internal:9000", + }}, + }, + } + + mitm, httpDomains := g.splitDomainsByScheme() + + assert.Equal(t, []string{"api.github.com", "api.telegram.org"}, mitm) + assert.Equal(t, []string{"host.containers.internal:8000", "host.containers.internal:9000"}, httpDomains) + }) +} + +func TestCollectHTTPPorts(t *testing.T) { + t.Run("extracts ports from HTTP services", func(t *testing.T) { + g := &Generator{ + Features: []*resolve.FeatureContributions{ + {HTTPServices: []resolve.HTTPService{ + {Host: "host.containers.internal", Port: "8000"}, + {Host: "host.containers.internal", Port: "9000"}, + }}, + }, + } + + ports := g.collectHTTPPorts() + + assert.Equal(t, []string{"8000", "9000"}, ports) + }) + + t.Run("defaults to port 80 when no port specified", func(t *testing.T) { + g := &Generator{ + Features: []*resolve.FeatureContributions{ + {HTTPServices: []resolve.HTTPService{ + {Host: "example.com", Port: ""}, + }}, + }, + } + + ports := g.collectHTTPPorts() + + assert.Equal(t, []string{"80"}, ports) + }) + + t.Run("deduplicates ports", func(t *testing.T) { + g := &Generator{ + Features: []*resolve.FeatureContributions{ + {HTTPServices: []resolve.HTTPService{ + {Host: "host1.internal", Port: "8000"}, + {Host: "host2.internal", Port: "8000"}, + }}, + }, + } + + ports := g.collectHTTPPorts() + + assert.Equal(t, []string{"8000"}, ports) + }) + + t.Run("no HTTP services returns empty", func(t *testing.T) { + g := &Generator{ + Features: []*resolve.FeatureContributions{ + {MITMDomains: []string{"api.github.com"}}, + }, + } + + ports := g.collectHTTPPorts() + + assert.Empty(t, ports) + }) +} + +func TestGatewayConfigBuilder_HTTPServices(t *testing.T) { + t.Run("renders HTTP services in config", func(t *testing.T) { + gcb := &GatewayConfigBuilder{ + ListenPort: 8443, + HTTPListenPort: 8080, + DNSPort: 5353, + HTTPServices: []resolve.HTTPService{ + {Host: "host.containers.internal", Port: "8000"}, + }, + } + + content, err := renderTemplate("gateway-config.yaml.tmpl", gcb) + require.NoError(t, err) + + assert.Contains(t, content, `http_listen: ":8080"`) + assert.Contains(t, content, "http_services:") + assert.Contains(t, content, "host.containers.internal") + assert.Contains(t, content, `port: "8000"`) + }) + + t.Run("omits http_services when empty", func(t *testing.T) { + gcb := &GatewayConfigBuilder{ + ListenPort: 8443, + HTTPListenPort: 8080, + DNSPort: 5353, + } + + content, err := renderTemplate("gateway-config.yaml.tmpl", gcb) + require.NoError(t, err) + + assert.NotContains(t, content, "http_services:") + }) +} diff --git a/internal/generate/helpers.go b/internal/generate/helpers.go index 771ae25..6dcc063 100644 --- a/internal/generate/helpers.go +++ b/internal/generate/helpers.go @@ -2,6 +2,8 @@ package generate import ( "fmt" + "net" + "net/url" "os" "path/filepath" "strings" @@ -53,25 +55,56 @@ func (g *Generator) hasHooks() bool { return false } -func (g *Generator) hasHomeOverride() bool { +func (g *Generator) hasRootHooks() bool { for _, f := range g.Features { - if f.HomeOverride != "" { + if len(f.RootHooks) > 0 { return true } } return false } -// hasMITMDomains returns true if any feature contributes MITM domains. -func (g *Generator) hasMITMDomains() bool { +// collectCapabilities gathers all capabilities from features, deduped. +func (g *Generator) collectCapabilities() []string { + var caps []string + seen := map[string]bool{} + for _, f := range g.Features { + for _, c := range f.Capabilities { + if !seen[c] { + seen[c] = true + caps = append(caps, c) + } + } + } + return caps +} + +// collectFeaturePorts gathers all port mappings from features. +func (g *Generator) collectFeaturePorts() []string { + var ports []string + for _, f := range g.Features { + ports = append(ports, f.Ports...) + } + return ports +} + +func (g *Generator) hasHomeOverride() bool { for _, f := range g.Features { - if len(f.MITMDomains) > 0 { + if f.HomeOverride != "" { return true } } return false } +// hasMITMDomains returns true if any feature contributes domains that require +// TLS interception (https:// or bare domains without scheme). HTTP-only domains +// are handled by the HTTP proxy and do not need MITM. +func (g *Generator) hasMITMDomains() bool { + mitmDomains, _ := g.splitDomainsByScheme() + return len(mitmDomains) > 0 +} + // collectMITMDomains gathers all MITM domains from features. func (g *Generator) collectMITMDomains() []string { var domains []string @@ -126,10 +159,20 @@ func (g *Generator) collectVolumePaths() []string { } // collectRewriters gathers all rewriter configs from features. +// Domains are normalized to strip the scheme but preserve host:port so that +// the gateway's AuthHeaderRewriter can do port-aware matching (two services +// on the same host with different ports get distinct rewriters). func (g *Generator) collectRewriters() []resolve.RewriterConfig { var rewriters []resolve.RewriterConfig for _, f := range g.Features { - rewriters = append(rewriters, f.Rewriters...) + for _, rw := range f.Rewriters { + normalized := make([]string, 0, len(rw.Domains)) + for _, d := range rw.Domains { + normalized = append(normalized, stripScheme(d)) + } + rw.Domains = normalized + rewriters = append(rewriters, rw) + } } return rewriters } @@ -177,6 +220,30 @@ func (g *Generator) collectHTTPPorts() []string { return ports } +// stripScheme removes the URL scheme from a domain string but preserves the +// host:port so that port-aware matching works correctly in the gateway. +// e.g. "http://host.internal:8000" -> "host.internal:8000" +// e.g. "https://api.github.com" -> "api.github.com" +func stripScheme(d string) string { + if strings.Contains(d, "://") { + parsed, err := url.Parse(d) + if err == nil { + return parsed.Host + } + } + return d +} + +// stripSchemeAndPort extracts the bare hostname from a domain string that may +// include a scheme and/or port (e.g. "http://host.internal:8000" -> "host.internal"). +func stripSchemeAndPort(d string) string { + d = stripScheme(d) + if h, _, err := net.SplitHostPort(d); err == nil { + return h + } + return d +} + // collectAgentEnv gathers agent-side environment variables from features. // These are dummy/non-secret values set in the agent container (e.g., GH_TOKEN=dummy). func (g *Generator) collectAgentEnv() []string { @@ -202,16 +269,16 @@ func (g *Generator) needsEntrypoint() bool { return true } for _, f := range g.Features { - if len(f.EntrypointHooks) > 0 || f.HomeOverride != "" { + if len(f.EntrypointHooks) > 0 || len(f.RootHooks) > 0 || f.HomeOverride != "" { return true } } return false } -// copyHooks copies entrypoint hook scripts to .build/hooks/. +// copyHooks copies entrypoint hook scripts and root hook scripts to .build/hooks/. func (g *Generator) copyHooks() error { - if !g.hasHooks() { + if !g.hasHooks() && !g.hasRootHooks() { return nil } @@ -232,6 +299,17 @@ func (g *Generator) copyHooks() error { return err } } + for _, hook := range f.RootHooks { + srcPath := filepath.Join(g.Dir, hook) + data, err := os.ReadFile(srcPath) + if err != nil { + return fmt.Errorf("reading root hook %s: %w", hook, err) + } + destPath := filepath.Join(hooksDir, filepath.Base(hook)) + if err := os.WriteFile(destPath, data, 0755); err != nil { + return err + } + } } return nil diff --git a/internal/generate/schema.go b/internal/generate/schema.go index 5b3812b..449bc20 100644 --- a/internal/generate/schema.go +++ b/internal/generate/schema.go @@ -3,6 +3,7 @@ package generate import ( "encoding/json" "fmt" + "maps" "os" "path/filepath" "reflect" @@ -143,9 +144,7 @@ func collectFeatureItemSchemas() []any { // Merge plugin-specific properties if pluginSchema != nil { if pluginProps, ok := pluginSchema["properties"].(map[string]any); ok { - for k, v := range pluginProps { - props[k] = v - } + maps.Copy(props, pluginProps) } // Carry over plugin-specific required fields if pluginRequired, ok := pluginSchema["required"].([]string); ok { @@ -155,8 +154,8 @@ func collectFeatureItemSchemas() []any { itemSchema := map[string]any{ "type": "object", - "properties": props, - "required": required, + "properties": props, + "required": required, "additionalProperties": false, } @@ -200,8 +199,7 @@ func structTypeToSchema(t reflect.Type) map[string]any { props := map[string]any{} var required []string - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) + for field := range t.Fields() { yamlTag := field.Tag.Get("yaml") if yamlTag == "" || yamlTag == "-" { continue diff --git a/internal/generate/templates.go b/internal/generate/templates.go index 622fa80..fc6b55a 100644 --- a/internal/generate/templates.go +++ b/internal/generate/templates.go @@ -4,6 +4,7 @@ import ( "bytes" "embed" "fmt" + "strings" "text/template" ) @@ -29,14 +30,14 @@ var templateFuncs = template.FuncMap{ return fmt.Sprintf("%q", s) }, "join": func(sep string, items []string) string { - result := "" + var result strings.Builder for i, item := range items { if i > 0 { - result += sep + result.WriteString(sep) } - result += item + result.WriteString(item) } - return result + return result.String() }, } diff --git a/internal/generate/templates/docker-compose.gateway.tmpl b/internal/generate/templates/docker-compose.gateway.tmpl index 3688b7a..3dda8c2 100644 --- a/internal/generate/templates/docker-compose.gateway.tmpl +++ b/internal/generate/templates/docker-compose.gateway.tmpl @@ -3,14 +3,19 @@ services: build: context: . dockerfile: Dockerfile.gateway + env_file: + - ../.env networks: internal: default: {{- range .ExternalNetworks }} {{ . }}: {{- end }} + cap_drop: + - ALL cap_add: - NET_ADMIN + - NET_BIND_SERVICE sysctls: - net.ipv4.ip_forward=1 environment: @@ -41,8 +46,24 @@ services: dockerfile: Dockerfile.agent networks: internal: + cap_drop: + - ALL cap_add: - NET_ADMIN + - DAC_OVERRIDE + - CHOWN + - SETUID + - SETGID + - FOWNER +{{- range .Capabilities }} + - {{ . }} +{{- end }} +{{- if .AgentPorts }} + ports: +{{- range .AgentPorts }} + - {{ quote . }} +{{- end }} +{{- end }} sysctls: - net.ipv4.conf.all.route_localnet=1 environment: @@ -68,6 +89,7 @@ services: restart: unless-stopped networks: + default: internal: internal: true {{- range .ExternalNetworks }} diff --git a/internal/generate/templates/docker-compose.single.tmpl b/internal/generate/templates/docker-compose.single.tmpl index 9178936..ac987ca 100644 --- a/internal/generate/templates/docker-compose.single.tmpl +++ b/internal/generate/templates/docker-compose.single.tmpl @@ -3,6 +3,16 @@ services: build: context: . dockerfile: Dockerfile + env_file: + - ../.env + cap_drop: + - ALL +{{- if .Capabilities }} + cap_add: +{{- range .Capabilities }} + - {{ . }} +{{- end }} +{{- end }} restart: unless-stopped {{- if .Ports }} ports: diff --git a/internal/generate/templates/entrypoint.agent.tmpl b/internal/generate/templates/entrypoint.agent.tmpl index d510f79..1c0439c 100644 --- a/internal/generate/templates/entrypoint.agent.tmpl +++ b/internal/generate/templates/entrypoint.agent.tmpl @@ -11,13 +11,18 @@ echo "entrypoint: gateway at $GATEWAY_IP" echo "entrypoint: switching DNS to gateway..." echo "nameserver $GATEWAY_IP" > /etc/resolv.conf +{{ if .HasMITM }} +echo "entrypoint: clearing /etc/hosts overrides for MITM..." +cp /etc/hosts /tmp/hosts.bak +grep -v 'host.containers.internal\|host.docker.internal' /tmp/hosts.bak > /tmp/hosts.new +cat /tmp/hosts.new > /etc/hosts +{{ end }} echo "entrypoint: setting default route via gateway..." ip route replace default via $GATEWAY_IP {{ range .Ports }} iptables -t nat -A PREROUTING -p tcp --dport {{ .ContainerPort }} -j DNAT --to-destination 127.0.0.1:{{ .ContainerPort }} {{ end }}{{ end }}{{ if .HasMITM }} -# Wait for sandbox CA certificate from gateway (shared volume) echo "entrypoint: waiting for sandbox CA certificate..." timeout=30 elapsed=0 @@ -38,5 +43,8 @@ if [ -d /opt/home-override ]; then cp -rT /opt/home-override /home/{{ .User }} chown -R {{ .User }}:{{ .User }} /home/{{ .User }} fi -{{ end }} +{{ end }}{{ if .HasRootHooks }} +echo "entrypoint: running root hooks..." +{{ range .RootHooks }}/opt/hooks/{{ . }} +{{ end }}{{ end }} exec su -c '{{ .UserCommand }}' {{ .User }} diff --git a/internal/generate/templates/gateway-config.yaml.tmpl b/internal/generate/templates/gateway-config.yaml.tmpl index 48ef27a..5e25a31 100644 --- a/internal/generate/templates/gateway-config.yaml.tmpl +++ b/internal/generate/templates/gateway-config.yaml.tmpl @@ -1,5 +1,6 @@ # Gateway configuration (auto-generated) listen: ":{{.ListenPort}}" +http_listen: ":{{.HTTPListenPort}}" dns_listen: ":{{.DNSPort}}" {{- if .MITMDomains}} mitm_domains: diff --git a/internal/plugins/external-services/plugin.go b/internal/plugins/external-services/plugin.go index 1459622..9b512f4 100644 --- a/internal/plugins/external-services/plugin.go +++ b/internal/plugins/external-services/plugin.go @@ -90,8 +90,26 @@ func init() { contrib.Rewriters = append(contrib.Rewriters, rewriters...) } + case "http": + if port == "" { + port = "80" + } + + contrib.HTTPServices = append(contrib.HTTPServices, resolve.HTTPService{ + Host: host, + Port: port, + }) + + if len(svc.Headers) > 0 { + rewriters, err := buildRewriters(host, svc.Headers) + if err != nil { + return nil, fmt.Errorf("external-services: service %q: %w", svc.URL, err) + } + contrib.Rewriters = append(contrib.Rewriters, rewriters...) + } + default: - return nil, fmt.Errorf("external-services: unsupported scheme %q in url %q (use docker:// or https://)", parsed.Scheme, svc.URL) + return nil, fmt.Errorf("external-services: unsupported scheme %q in url %q (use http://, https://, or docker://)", parsed.Scheme, svc.URL) } } diff --git a/internal/plugins/external-services/plugin_test.go b/internal/plugins/external-services/plugin_test.go index fdab434..e80ce56 100644 --- a/internal/plugins/external-services/plugin_test.go +++ b/internal/plugins/external-services/plugin_test.go @@ -139,7 +139,7 @@ func TestExternalServices_DockerMissingNetworkError(t *testing.T) { func TestExternalServices_UnsupportedSchemeError(t *testing.T) { _, err := resolve.ResolveFeature(".", "external-services", "external-services", map[string]any{ "services": []any{ - map[string]any{"url": "http://foo.com", "network": "n"}, + map[string]any{"url": "ftp://foo.com"}, }, }) assert.Error(t, err) diff --git a/internal/plugins/register.go b/internal/plugins/register.go index 5420805..6a2f5f7 100644 --- a/internal/plugins/register.go +++ b/internal/plugins/register.go @@ -10,6 +10,6 @@ import ( _ "github.com/donbader/agent-sandbox/internal/plugins/github-pat" _ "github.com/donbader/agent-sandbox/internal/plugins/mcp-oauth" _ "github.com/donbader/agent-sandbox/internal/plugins/pi" - + _ "github.com/donbader/agent-sandbox/internal/plugins/ssh" _ "github.com/donbader/agent-sandbox/internal/plugins/telegram" ) diff --git a/internal/plugins/ssh/plugin.go b/internal/plugins/ssh/plugin.go new file mode 100644 index 0000000..a3c8ada --- /dev/null +++ b/internal/plugins/ssh/plugin.go @@ -0,0 +1,151 @@ +// Package ssh implements the SSH feature plugin. +// It provides an SSH server inside the agent container for remote development access. +package ssh + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + + "github.com/donbader/agent-sandbox/internal/resolve" +) + +const defaultPort = 2222 +const defaultHostKeyPath = ".ssh_host_key" + +// Config defines the typed configuration for the ssh plugin. +type Config struct { + Port int `yaml:"port" schema:"SSH port inside the container" default:"2222" examples:"2222,22"` + AuthorizedKeys string `yaml:"authorized_keys" schema:"Path to public key file (relative to agent.yaml dir)" required:"true" examples:"./ssh_key.pub"` + HostKey string `yaml:"host_key" schema:"Path to persistent host private key (auto-generated if absent)" default:".ssh_host_key"` +} + +func generateHostKey(path string) error { + cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", path, "-N", "", "-C", "") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} + +func init() { + resolve.Register("ssh", func(projectDir string, cfg Config) (*resolve.FeatureContributions, error) { + if cfg.AuthorizedKeys == "" { + return nil, fmt.Errorf("ssh: missing required option 'authorized_keys'") + } + + port := cfg.Port + if port == 0 { + port = defaultPort + } + if port < 1 || port > 65535 { + return nil, fmt.Errorf("ssh: port must be between 1 and 65535, got %d", port) + } + + // Default host_key path if not specified + if cfg.HostKey == "" { + cfg.HostKey = defaultHostKeyPath + } + + // Validate the authorized_keys file exists at generate time. + keyPath := cfg.AuthorizedKeys + if !filepath.IsAbs(keyPath) { + keyPath = filepath.Join(projectDir, keyPath) + } + absKeyPath, err := filepath.Abs(keyPath) + if err != nil { + return nil, fmt.Errorf("ssh: resolving path %q: %w", cfg.AuthorizedKeys, err) + } + absProject, err := filepath.Abs(projectDir) + if err != nil { + return nil, fmt.Errorf("ssh: resolving project dir: %w", err) + } + if !strings.HasPrefix(absKeyPath, absProject+string(filepath.Separator)) && absKeyPath != absProject { + return nil, fmt.Errorf("ssh: path %q escapes project directory", cfg.AuthorizedKeys) + } + if _, err := os.Stat(keyPath); err != nil { + return nil, fmt.Errorf("ssh: reading authorized_keys file %q: %w", cfg.AuthorizedKeys, err) + } + + portStr := strconv.Itoa(port) + + scriptsDir := filepath.Join(projectDir, "scripts") + if err := os.MkdirAll(scriptsDir, 0o755); err != nil { + return nil, fmt.Errorf("ssh: creating scripts directory: %w", err) + } + + // Resolve and auto-generate host key if absent. + hostKeyPath := cfg.HostKey + if !filepath.IsAbs(hostKeyPath) { + hostKeyPath = filepath.Join(projectDir, hostKeyPath) + } + absHostKeyPath, err := filepath.Abs(hostKeyPath) + if err != nil { + return nil, fmt.Errorf("ssh: resolving path %q: %w", cfg.HostKey, err) + } + if !strings.HasPrefix(absHostKeyPath, absProject+string(filepath.Separator)) && absHostKeyPath != absProject { + return nil, fmt.Errorf("ssh: path %q escapes project directory", cfg.HostKey) + } + if _, err := os.Stat(hostKeyPath); err != nil { + if os.IsNotExist(err) { + if err := generateHostKey(hostKeyPath); err != nil { + return nil, fmt.Errorf("ssh: generating host key at %q: %w", cfg.HostKey, err) + } + } else { + return nil, fmt.Errorf("ssh: checking host key at %q: %w", cfg.HostKey, err) + } + } + + // The root hook script copies mounted key files into place. + // Keys are bind-mounted at /run/ssh/ from the host. + rootHook := fmt.Sprintf(`#!/bin/bash +set -e +cp /run/ssh/host_key /etc/ssh/ssh_host_ed25519_key +chmod 600 /etc/ssh/ssh_host_ed25519_key +ssh-keygen -y -f /etc/ssh/ssh_host_ed25519_key > /etc/ssh/ssh_host_ed25519_key.pub +mkdir -p /home/agent/.ssh +cp /run/ssh/authorized_keys /home/agent/.ssh/authorized_keys +chown -R agent:agent /home/agent/.ssh +/usr/sbin/sshd -p %s +`, portStr) + + rootHookPath := filepath.Join(scriptsDir, "ssh-root-setup.sh") + if err := os.WriteFile(rootHookPath, []byte(rootHook), 0o755); err != nil { + return nil, fmt.Errorf("ssh: writing root hook script: %w", err) + } + + permsHook := `#!/bin/bash +set -e +chmod 700 /home/agent/.ssh +chmod 600 /home/agent/.ssh/authorized_keys +` + permsHookPath := filepath.Join(scriptsDir, "ssh-perms.sh") + if err := os.WriteFile(permsHookPath, []byte(permsHook), 0o755); err != nil { + return nil, fmt.Errorf("ssh: writing entrypoint hook script: %w", err) + } + + portMapping := fmt.Sprintf("%s:%s", portStr, portStr) + + // Volume mounts: compose file lives in .build/, keys are in project root. + // Use relative path from .build/ back to project root. + hostKeyVolume := fmt.Sprintf("../%s:/run/ssh/host_key:ro", cfg.HostKey) + authKeysVolume := fmt.Sprintf("../%s:/run/ssh/authorized_keys:ro", cfg.AuthorizedKeys) + + return &resolve.FeatureContributions{ + Name: "ssh", + Commands: []string{ + "apt-get update && apt-get install -y --no-install-recommends openssh-server && rm -rf /var/lib/apt/lists/*", + "mkdir -p /run/sshd", + fmt.Sprintf("sed -i 's/^#*Port.*/Port %s/' /etc/ssh/sshd_config", portStr), + "sed -i 's/^#*PasswordAuthentication.*/PasswordAuthentication no/' /etc/ssh/sshd_config", + }, + RootHooks: []string{"scripts/ssh-root-setup.sh"}, + EntrypointHooks: []string{"scripts/ssh-perms.sh"}, + Volumes: []string{hostKeyVolume, authKeysVolume}, + Capabilities: []string{"SYS_CHROOT"}, + Ports: []string{portMapping}, + }, nil + }) +} diff --git a/internal/plugins/ssh/plugin_test.go b/internal/plugins/ssh/plugin_test.go new file mode 100644 index 0000000..190aa1f --- /dev/null +++ b/internal/plugins/ssh/plugin_test.go @@ -0,0 +1,153 @@ +package ssh + +import ( + "os" + "path/filepath" + "testing" + + "github.com/donbader/agent-sandbox/internal/resolve" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSHPlugin_DefaultPort(t *testing.T) { + projectDir := t.TempDir() + pubkeyFile := filepath.Join(projectDir, "id_ed25519.pub") + require.NoError(t, os.WriteFile(pubkeyFile, []byte("ssh-ed25519 AAAAC3Nz testuser@host\n"), 0o644)) + + plugin := resolve.RegisteredPlugins()["ssh"] + require.NotNil(t, plugin, "ssh plugin not registered") + + contrib, err := plugin.Resolve(projectDir, map[string]any{ + "authorized_keys": "id_ed25519.pub", + }) + require.NoError(t, err) + + assert.Equal(t, "ssh", contrib.Name) + assert.Equal(t, []string{"2222:2222"}, contrib.Ports) + assert.Equal(t, []string{"SYS_CHROOT"}, contrib.Capabilities) + assert.Equal(t, []string{"scripts/ssh-root-setup.sh"}, contrib.RootHooks) + assert.Equal(t, []string{"scripts/ssh-perms.sh"}, contrib.EntrypointHooks) + + require.Len(t, contrib.Commands, 4) + assert.Contains(t, contrib.Commands[0], "openssh-server") + assert.Contains(t, contrib.Commands[2], "Port 2222") + assert.Contains(t, contrib.Commands[3], "PasswordAuthentication no") + + // Volume mounts for key files (bind mounts from project root via .build/) + assert.Contains(t, contrib.Volumes, "../.ssh_host_key:/run/ssh/host_key:ro") + assert.Contains(t, contrib.Volumes, "../id_ed25519.pub:/run/ssh/authorized_keys:ro") +} + +func TestSSHPlugin_CustomPort(t *testing.T) { + projectDir := t.TempDir() + pubkeyFile := filepath.Join(projectDir, "key.pub") + require.NoError(t, os.WriteFile(pubkeyFile, []byte("ssh-ed25519 AAAAC3Nz testuser@host\n"), 0o644)) + + plugin := resolve.RegisteredPlugins()["ssh"] + require.NotNil(t, plugin, "ssh plugin not registered") + + contrib, err := plugin.Resolve(projectDir, map[string]any{ + "authorized_keys": "key.pub", + "port": 8022, + }) + require.NoError(t, err) + + assert.Equal(t, []string{"8022:8022"}, contrib.Ports) + assert.Contains(t, contrib.Commands[2], "Port 8022") +} + +func TestSSHPlugin_WritesRootHookScript(t *testing.T) { + projectDir := t.TempDir() + pubkeyFile := filepath.Join(projectDir, "id_rsa.pub") + pubkey := "ssh-rsa AAAAB3NzaC1yc2EAAA testuser@host" + require.NoError(t, os.WriteFile(pubkeyFile, []byte(pubkey+"\n"), 0o644)) + + plugin := resolve.RegisteredPlugins()["ssh"] + require.NotNil(t, plugin, "ssh plugin not registered") + + _, err := plugin.Resolve(projectDir, map[string]any{ + "authorized_keys": "id_rsa.pub", + }) + require.NoError(t, err) + + rootHookPath := filepath.Join(projectDir, "scripts", "ssh-root-setup.sh") + content, err := os.ReadFile(rootHookPath) + require.NoError(t, err) + + script := string(content) + + // Script references the mounted paths, not embedded key material + assert.Contains(t, script, "cp /run/ssh/host_key /etc/ssh/ssh_host_ed25519_key") + assert.Contains(t, script, "cp /run/ssh/authorized_keys /home/agent/.ssh/authorized_keys") + assert.Contains(t, script, "/usr/sbin/sshd -p 2222") + assert.Contains(t, script, "chown -R agent:agent /home/agent/.ssh") + assert.Contains(t, script, "chmod 600 /etc/ssh/ssh_host_ed25519_key") + assert.Contains(t, script, "ssh-keygen -y -f /etc/ssh/ssh_host_ed25519_key") + + // Must NOT contain any actual key material + assert.NotContains(t, script, pubkey) + assert.NotContains(t, script, "HOSTKEY") + assert.NotContains(t, script, "PUBKEY") +} + +func TestSSHPlugin_WritesPermsHookScript(t *testing.T) { + projectDir := t.TempDir() + pubkeyFile := filepath.Join(projectDir, "key.pub") + require.NoError(t, os.WriteFile(pubkeyFile, []byte("ssh-ed25519 AAAAC3Nz testuser@host\n"), 0o644)) + + plugin := resolve.RegisteredPlugins()["ssh"] + require.NotNil(t, plugin, "ssh plugin not registered") + + _, err := plugin.Resolve(projectDir, map[string]any{ + "authorized_keys": "key.pub", + }) + require.NoError(t, err) + + permsHookPath := filepath.Join(projectDir, "scripts", "ssh-perms.sh") + content, err := os.ReadFile(permsHookPath) + require.NoError(t, err) + + assert.Contains(t, string(content), "chmod 700 /home/agent/.ssh") + assert.Contains(t, string(content), "chmod 600 /home/agent/.ssh/authorized_keys") +} + +func TestSSHPlugin_ErrorsWithoutAuthorizedKeys(t *testing.T) { + plugin := resolve.RegisteredPlugins()["ssh"] + require.NotNil(t, plugin, "ssh plugin not registered") + + _, err := plugin.Resolve(t.TempDir(), map[string]any{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing required option 'authorized_keys'") +} + +func TestSSHPlugin_ErrorsWhenKeyFileNotFound(t *testing.T) { + plugin := resolve.RegisteredPlugins()["ssh"] + require.NotNil(t, plugin, "ssh plugin not registered") + + _, err := plugin.Resolve(t.TempDir(), map[string]any{ + "authorized_keys": "nonexistent.pub", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "reading authorized_keys file") +} + +func TestSSHPlugin_VolumeMountsUseRelativePaths(t *testing.T) { + projectDir := t.TempDir() + pubkeyFile := filepath.Join(projectDir, "ssh_key.pub") + require.NoError(t, os.WriteFile(pubkeyFile, []byte("ssh-ed25519 AAAAC3Nz testuser@host\n"), 0o644)) + + plugin := resolve.RegisteredPlugins()["ssh"] + require.NotNil(t, plugin, "ssh plugin not registered") + + contrib, err := plugin.Resolve(projectDir, map[string]any{ + "authorized_keys": "ssh_key.pub", + "host_key": "my_host_key", + }) + require.NoError(t, err) + + // Volumes should use relative paths from .build/ to project root + require.Len(t, contrib.Volumes, 2) + assert.Equal(t, "../my_host_key:/run/ssh/host_key:ro", contrib.Volumes[0]) + assert.Equal(t, "../ssh_key.pub:/run/ssh/authorized_keys:ro", contrib.Volumes[1]) +} diff --git a/internal/resolve/plugin.go b/internal/resolve/plugin.go index a5add68..bb78996 100644 --- a/internal/resolve/plugin.go +++ b/internal/resolve/plugin.go @@ -73,19 +73,22 @@ type HTTPService struct { // FeatureContributions holds what a feature adds to the build. type FeatureContributions struct { - Name string // plugin name (for diagnostics and logging) - Commands []string // RUN commands for Dockerfile - EntrypointHooks []string // scripts to run on container start (source paths) - Volumes []string // named volumes (e.g., "name:/path") - HomeOverride string // directory to copy into home on start - MITMDomains []string // domains the gateway should MITM (terminate TLS) - ChannelName string // channel type (e.g., "telegram") - AgentEnv []string // environment variables for agent container (dummy values, not secrets) + Name string // plugin name (for diagnostics and logging) + Commands []string // RUN commands for Dockerfile + EntrypointHooks []string // scripts to run on container start (source paths) + RootHooks []string // scripts to run as root before dropping to agent user (source paths) + Volumes []string // named volumes (e.g., "name:/path") + HomeOverride string // directory to copy into home on start + MITMDomains []string // domains the gateway should MITM (terminate TLS) + ChannelName string // channel type (e.g., "telegram") + AgentEnv []string // environment variables for agent container (dummy values, not secrets) ChannelConfig map[string]any // plugin-specific config passed to channel-manager-config.json - Rewriters []RewriterConfig // gateway rewriters to instantiate for this feature - CommandPluginDir string // path to TypeScript command plugin source (copied into channel-manager) - ExternalNetworks []string // external Docker networks the gateway should join - HTTPServices []HTTPService // plain HTTP services to proxy with header injection + Rewriters []RewriterConfig // gateway rewriters to instantiate for this feature + CommandPluginDir string // path to TypeScript command plugin source (copied into channel-manager) + ExternalNetworks []string // external Docker networks the gateway should join + HTTPServices []HTTPService // plain HTTP services to proxy with header injection + Capabilities []string // additional Linux capabilities for the agent container (e.g., "SYS_CHROOT") + Ports []string // host:container port mappings to expose (e.g., "2222:2222") } // registry holds registered feature plugins. diff --git a/internal/runtime/detect.go b/internal/runtime/detect.go new file mode 100644 index 0000000..d5ff872 --- /dev/null +++ b/internal/runtime/detect.go @@ -0,0 +1,99 @@ +// Package runtime detects the container runtime available on the host. +package runtime + +import ( + "fmt" + "os/exec" +) + +// Runtime identifies a container runtime engine. +type Runtime string + +const ( + Docker Runtime = "docker" + Podman Runtime = "podman" +) + +// Detected holds the result of container runtime detection. +type Detected struct { + Runtime Runtime + Binary string + ComposeCmd []string +} + +// DetectWithOverride returns the detected runtime. If override is "docker" or +// "podman", that value is used directly (after verifying the binary exists on +// PATH). Otherwise it falls back to PATH auto-detection (podman preferred over +// docker). Returns an error if no runtime is found. +func DetectWithOverride(override string) (*Detected, error) { + if override != "" { + return resolveOverride(override) + } + return detectFromPath() +} + +// DetectOrDefault returns the detected runtime via PATH auto-detection, +// falling back to Docker defaults if detection fails. +func DetectOrDefault() *Detected { + return DetectOrDefaultWithOverride("") +} + +// DetectOrDefaultWithOverride is like DetectOrDefault but accepts an override +// value that takes precedence over PATH detection. +func DetectOrDefaultWithOverride(override string) *Detected { + d, err := DetectWithOverride(override) + if err != nil { + return &Detected{ + Runtime: Docker, + Binary: "docker", + ComposeCmd: []string{"docker", "compose"}, + } + } + return d +} + +func resolveOverride(val string) (*Detected, error) { + switch Runtime(val) { + case Docker: + if _, err := exec.LookPath("docker"); err != nil { + return nil, fmt.Errorf("container_runtime set to %q but binary not found on PATH", val) + } + return buildDetected(Docker), nil + case Podman: + if _, err := exec.LookPath("podman"); err != nil { + return nil, fmt.Errorf("container_runtime set to %q but binary not found on PATH", val) + } + return buildDetected(Podman), nil + default: + return nil, fmt.Errorf("unsupported container_runtime value %q: must be \"docker\" or \"podman\"", val) + } +} + +func detectFromPath() (*Detected, error) { + if _, err := exec.LookPath("podman"); err == nil { + return buildDetected(Podman), nil + } + if _, err := exec.LookPath("docker"); err == nil { + return buildDetected(Docker), nil + } + return nil, fmt.Errorf("no container runtime found: install docker or podman and ensure it is on PATH") +} + +func buildDetected(rt Runtime) *Detected { + binary := string(rt) + composeCmd := []string{binary, "compose"} + + if rt == Podman { + if err := exec.Command("podman", "compose", "version").Run(); err != nil { + if _, err2 := exec.LookPath("podman-compose"); err2 == nil { + composeCmd = []string{"podman-compose"} + } + } + } + + return &Detected{ + Runtime: rt, + Binary: binary, + ComposeCmd: composeCmd, + } +} diff --git a/internal/runtime/detect_test.go b/internal/runtime/detect_test.go new file mode 100644 index 0000000..4f01036 --- /dev/null +++ b/internal/runtime/detect_test.go @@ -0,0 +1,100 @@ +package runtime + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDetectWithOverride_Docker(t *testing.T) { + d, err := DetectWithOverride("docker") + if err != nil { + t.Skipf("docker not on PATH: %v", err) + } + + assert.Equal(t, Docker, d.Runtime) + assert.Equal(t, "docker", d.Binary) + assert.Equal(t, []string{"docker", "compose"}, d.ComposeCmd) +} + +func TestDetectWithOverride_Podman(t *testing.T) { + d, err := DetectWithOverride("podman") + if err != nil { + t.Skipf("podman not on PATH: %v", err) + } + + assert.Equal(t, Podman, d.Runtime) + assert.Equal(t, "podman", d.Binary) +} + +func TestDetectWithOverride_InvalidReturnsError(t *testing.T) { + _, err := DetectWithOverride("containerd") + + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported container_runtime value") + assert.Contains(t, err.Error(), "containerd") +} + +func TestDetectWithOverride_EmptyFallsToPath(t *testing.T) { + d, err := DetectWithOverride("") + if err != nil { + t.Skipf("no container runtime on PATH: %v", err) + } + + // Should be one of the two valid runtimes detected from PATH + assert.Contains(t, []Runtime{Docker, Podman}, d.Runtime) + assert.Equal(t, string(d.Runtime), d.Binary) +} + +func TestDetectWithOverride_IgnoresEnvVar(t *testing.T) { + // Env var should have no effect — only the override param matters + t.Setenv("CONTAINER_RUNTIME", "podman") + + d, err := DetectWithOverride("docker") + if err != nil { + t.Skipf("docker not on PATH: %v", err) + } + + assert.Equal(t, Docker, d.Runtime) +} + +func TestDetectOrDefault_ReturnsDockerFallback(t *testing.T) { + // DetectOrDefault calls DetectWithOverride("") which does PATH detection. + // If PATH has no runtime, it returns Docker defaults. + d := DetectOrDefault() + + // Should always succeed — either finds a runtime or falls back to Docker + assert.Contains(t, []Runtime{Docker, Podman}, d.Runtime) + assert.Equal(t, string(d.Runtime), d.Binary) +} + +func TestDetectOrDefaultWithOverride_UsesOverride(t *testing.T) { + d := DetectOrDefaultWithOverride("docker") + // If docker isn't on PATH, it falls back to default (which is also docker) + assert.Equal(t, Docker, d.Runtime) + assert.Equal(t, "docker", d.Binary) +} + +func TestDetectOrDefaultWithOverride_InvalidFallsToDefault(t *testing.T) { + d := DetectOrDefaultWithOverride("bogus") + + assert.Equal(t, Docker, d.Runtime) + assert.Equal(t, "docker", d.Binary) + assert.Equal(t, []string{"docker", "compose"}, d.ComposeCmd) +} + +func TestBuildDetected_ComposeCmdDocker(t *testing.T) { + d := buildDetected(Docker) + + assert.Equal(t, []string{"docker", "compose"}, d.ComposeCmd) +} + +func TestBuildDetected_ComposeCmdPodman(t *testing.T) { + d := buildDetected(Podman) + + // Podman compose command depends on what's installed; + // verify the binary field is correct regardless. + assert.Equal(t, "podman", d.Binary) + assert.Equal(t, Podman, d.Runtime) +} diff --git a/tests/integration/build_test.go b/tests/integration/build_test.go index 708435c..fecd841 100644 --- a/tests/integration/build_test.go +++ b/tests/integration/build_test.go @@ -10,12 +10,14 @@ import ( "github.com/donbader/agent-sandbox/internal/config" "github.com/donbader/agent-sandbox/internal/generate" + "github.com/donbader/agent-sandbox/internal/runtime" "github.com/donbader/agent-sandbox/plugins/codex" "github.com/stretchr/testify/require" ) func TestCodexImage_Builds(t *testing.T) { outDir := t.TempDir() + rt := runtime.DetectOrDefault() g := &generate.Generator{ Config: &config.AgentConfig{ @@ -29,27 +31,28 @@ func TestCodexImage_Builds(t *testing.T) { require.NoError(t, g.Run()) - // Verify docker build succeeds - cmd := exec.Command("docker", "build", "-t", "agent-sandbox-test-codex", outDir) + // Verify container build succeeds + cmd := exec.Command(rt.Binary, "build", "-t", "agent-sandbox-test-codex", outDir) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr err := cmd.Run() - require.NoError(t, err, "docker build failed") + require.NoError(t, err, "container build failed") // Cleanup image t.Cleanup(func() { - cleanup := exec.Command("docker", "rmi", "agent-sandbox-test-codex") + cleanup := exec.Command(rt.Binary, "rmi", "agent-sandbox-test-codex") _ = cleanup.Run() }) // Verify codex is installed - out, err := exec.Command("docker", "run", "--rm", "agent-sandbox-test-codex", "codex", "--version").CombinedOutput() + out, err := exec.Command(rt.Binary, "run", "--rm", "agent-sandbox-test-codex", "codex", "--version").CombinedOutput() require.NoError(t, err, "codex --version failed: %s", string(out)) t.Logf("codex version: %s", string(out)) } func TestCodexImage_AgentUser(t *testing.T) { outDir := t.TempDir() + rt := runtime.DetectOrDefault() g := &generate.Generator{ Config: &config.AgentConfig{ @@ -63,21 +66,21 @@ func TestCodexImage_AgentUser(t *testing.T) { require.NoError(t, g.Run()) - cmd := exec.Command("docker", "build", "-t", "agent-sandbox-test-user", outDir) + cmd := exec.Command(rt.Binary, "build", "-t", "agent-sandbox-test-user", outDir) require.NoError(t, cmd.Run()) t.Cleanup(func() { - cleanup := exec.Command("docker", "rmi", "agent-sandbox-test-user") + cleanup := exec.Command(rt.Binary, "rmi", "agent-sandbox-test-user") _ = cleanup.Run() }) // Verify runs as agent user - out, err := exec.Command("docker", "run", "--rm", "agent-sandbox-test-user", "whoami").CombinedOutput() + out, err := exec.Command(rt.Binary, "run", "--rm", "agent-sandbox-test-user", "whoami").CombinedOutput() require.NoError(t, err) require.Contains(t, string(out), "agent") // Verify workdir is /home/agent - out, err = exec.Command("docker", "run", "--rm", "agent-sandbox-test-user", "pwd").CombinedOutput() + out, err = exec.Command(rt.Binary, "run", "--rm", "agent-sandbox-test-user", "pwd").CombinedOutput() require.NoError(t, err) require.Contains(t, string(out), "/home/agent") }