diff --git a/README.md b/README.md index b55554c..7174e56 100644 --- a/README.md +++ b/README.md @@ -107,8 +107,19 @@ agent-scanner inspect --print-full-descriptions Show full entity descriptions --analysis-url URL Remote verification server URL --control-server URL Upload results to control server +--dangerously-run-mcp-servers Start every stdio MCP server without consent prompts +--ci Exit non-zero on findings/failures (requires --dangerously-run-mcp-servers) +--ignore-issues-codes CODES Comma-separated issue codes to ignore for the CI exit (only with --ci) ``` +By default, `scan` and `inspect` prompt before launching each stdio MCP server as +a subprocess (the command and redacted env are shown). Pass +`--dangerously-run-mcp-servers` to start them all without prompting — required for +non-interactive `--ci` runs. When `scan` is given a control server with an +`x-client-id` header, the run is treated as automated and prompts are skipped +(`inspect` has no control servers, so it always prompts unless +`--dangerously-run-mcp-servers` is set). + ### JSON output ```bash diff --git a/internal/cli/consent.go b/internal/cli/consent.go new file mode 100644 index 0000000..a591c88 --- /dev/null +++ b/internal/cli/consent.go @@ -0,0 +1,187 @@ +package cli + +import ( + "context" + "errors" + "fmt" + "os" + "sort" + "strings" + "time" + + "github.com/go-authgate/agent-scanner/internal/consent" + "github.com/go-authgate/agent-scanner/internal/models" +) + +// scanRunTimeout bounds a non-interactive scan/inspect run. +const scanRunTimeout = 5 * time.Minute + +// scanContext derives the run context. When bounded (no consent prompt will +// block on human input — e.g. a push key marks an automated run, or +// --dangerously-run-mcp-servers skips prompts), it applies scanRunTimeout; +// otherwise it returns the parent unbounded so interactive consent prompts are +// not raced by the deadline. Per-server connection timeouts still bound each +// server either way. +func scanContext(parent context.Context, bounded bool) (context.Context, context.CancelFunc) { + if bounded { + return context.WithTimeout(parent, scanRunTimeout) + } + return context.WithCancel(parent) +} + +// consentExitError carries a process exit code out of a command so the caller +// can exit with the exact code interactive/CI runs require. +type consentExitError struct { + code int + msg string +} + +func (e *consentExitError) Error() string { return e.msg } + +// exitOnConsentError prints the message and exits with the carried code when err +// is a consentExitError; otherwise it does nothing. +func exitOnConsentError(err error) { + var ce *consentExitError + if errors.As(err, &ce) { + fmt.Fprintln(os.Stderr, ce.msg) + os.Exit(ce.code) + } +} + +// validateConsentFlags enforces the CI flag relationships: +// - --ci requires --dangerously-run-mcp-servers (CI cannot answer prompts) +// - --ignore-issues-codes is only valid together with --ci +// +// On violation it returns a consentExitError with exit code 2. +func validateConsentFlags() error { + if commonFlags.CI && !commonFlags.DangerouslyRun { + return &consentExitError{ + code: 2, + msg: "running with --ci requires --dangerously-run-mcp-servers; " + + "CI runs start a subprocess for every stdio MCP server, so trust must be confirmed explicitly", + } + } + if commonFlags.IgnoreIssueCodes != "" && !commonFlags.CI { + return &consentExitError{code: 2, msg: "--ignore-issues-codes can only be used with --ci"} + } + return nil +} + +// hasPushKey reports whether any configured control server carries an +// x-client-id header, which marks an automated (non-interactive) run that skips +// consent prompts. Only headers index-aligned with an actual --control-server +// count (matching parseControlServers); stray --control-server-H values do not. +func hasPushKey() bool { + for _, cs := range parseControlServers() { + for header := range cs.Headers { + if strings.EqualFold(header, "x-client-id") { + return true + } + } + } + return false +} + +// buildConsentFn returns the pipeline consent hook, printing the appropriate +// banner. It returns nil (run every server, no gating) when prompts are skipped: +// when --dangerously-run-mcp-servers is set, or when the run is non-interactive +// (a push key is configured). interactive=true with no dangerous flag yields the +// interactive prompt collector. +func buildConsentFn(interactive bool) func([]*models.ClientToInspect) map[models.ServerRef]bool { + if commonFlags.DangerouslyRun { + if interactive { + fmt.Fprint(os.Stderr, + "--dangerously-run-mcp-servers is set: starting every stdio MCP server "+ + "in the scanned configs without prompting.\n\n") + } + return nil + } + if !interactive { + return nil + } + return func(clients []*models.ClientToInspect) map[models.ServerRef]bool { + return consent.CollectConsent(clients, os.Stderr, os.Stdin) + } +} + +// parseIgnoreCodes splits --ignore-issues-codes into a set. +func parseIgnoreCodes() map[string]bool { + codes := map[string]bool{} + for c := range strings.SplitSeq(commonFlags.IgnoreIssueCodes, ",") { + if c = strings.TrimSpace(c); c != "" { + codes[c] = true + } + } + return codes +} + +// applyIgnoreCodes removes issue findings whose code is in --ignore-issues-codes +// so they are not printed. It filters only Issues; runtime-failure X-codes +// (surfaced via result/server errors) are not filtered here and may still be +// shown under --print-errors. No-op outside CI mode; call before formatting. +func applyIgnoreCodes(results []models.ScanPathResult) { + if !commonFlags.CI { + return + } + ignore := parseIgnoreCodes() + if len(ignore) == 0 { + return + } + for i := range results { + kept := results[i].Issues[:0] + for _, issue := range results[i].Issues { + if !ignore[issue.Code] { + kept = append(kept, issue) + } + } + results[i].Issues = kept + } +} + +// ciExitError returns a consentExitError (exit code 1) if, in CI mode, any issue +// or runtime-failure code remains after --ignore-issues-codes. It does not +// mutate results; call applyIgnoreCodes first to filter the printed output. +func ciExitError(results []models.ScanPathResult) error { + if !commonFlags.CI { + return nil + } + ignore := parseIgnoreCodes() + + remaining := map[string]bool{} + for i := range results { + for _, issue := range results[i].Issues { + if issue.Code != "" && !ignore[issue.Code] { + remaining[issue.Code] = true + } + } + // Runtime failures surface as X-codes on the path and its servers. + if results[i].Error != nil && results[i].Error.IsFailure { + if code := models.ErrorToIssueCode(results[i].Error.Category); !ignore[code] { + remaining[code] = true + } + } + for _, s := range results[i].Servers { + if s.Error != nil && s.Error.IsFailure { + if code := models.ErrorToIssueCode(s.Error.Category); !ignore[code] { + remaining[code] = true + } + } + } + } + + if len(remaining) == 0 { + return nil + } + codes := make([]string, 0, len(remaining)) + for c := range remaining { + codes = append(codes, c) + } + sort.Strings(codes) + return &consentExitError{ + code: 1, + msg: fmt.Sprintf( + "CI (--ci): exiting with code 1 (issue codes: %s)", + strings.Join(codes, ", "), + ), + } +} diff --git a/internal/cli/consent_test.go b/internal/cli/consent_test.go new file mode 100644 index 0000000..466c102 --- /dev/null +++ b/internal/cli/consent_test.go @@ -0,0 +1,199 @@ +package cli + +import ( + "errors" + "testing" + + "github.com/go-authgate/agent-scanner/internal/models" +) + +// resetFlags clears the package-global flag state mutated by these tests. +func resetFlags(t *testing.T) { + t.Helper() + commonFlags = CommonFlags{} + scanFlags = ScanFlags{} +} + +func TestValidateConsentFlags(t *testing.T) { + tests := []struct { + name string + ci bool + dangerous bool + ignore string + wantCode int // 0 = no error + }{ + {"no ci, no flags", false, false, "", 0}, + {"ci with dangerous", true, true, "", 0}, + {"ci without dangerous", true, false, "", 2}, + {"ignore codes without ci", false, false, "W001", 2}, + {"ignore codes with ci+dangerous", true, true, "W001", 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetFlags(t) + commonFlags.CI = tt.ci + commonFlags.DangerouslyRun = tt.dangerous + commonFlags.IgnoreIssueCodes = tt.ignore + + err := validateConsentFlags() + if tt.wantCode == 0 { + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + return + } + var ce *consentExitError + if !errors.As(err, &ce) { + t.Fatalf("expected consentExitError, got %v", err) + } + if ce.code != tt.wantCode { + t.Errorf("exit code = %d, want %d", ce.code, tt.wantCode) + } + }) + } +} + +func TestHasPushKey(t *testing.T) { + tests := []struct { + name string + servers []string + headers []string + want bool + }{ + {"none", nil, nil, false}, + {"unrelated header", []string{"https://cs"}, []string{"Authorization: Bearer x"}, false}, + {"x-client-id present", []string{"https://cs"}, []string{"x-client-id: abc123"}, true}, + {"case insensitive", []string{"https://cs"}, []string{"X-Client-Id: abc"}, true}, + { + "prefix must not match", + []string{"https://cs"}, + []string{"X-Client-Identity: abc"}, + false, + }, + // A stray header with no matching --control-server must not count. + {"unattached header ignored", nil, []string{"x-client-id: abc"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetFlags(t) + scanFlags.ControlServers = tt.servers + scanFlags.ControlHeaders = tt.headers + if got := hasPushKey(); got != tt.want { + t.Errorf("hasPushKey = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBuildConsentFn(t *testing.T) { + t.Run("dangerous skips gating", func(t *testing.T) { + resetFlags(t) + commonFlags.DangerouslyRun = true + if buildConsentFn(true) != nil { + t.Error("dangerous run should return nil consent fn") + } + }) + t.Run("non-interactive skips gating", func(t *testing.T) { + resetFlags(t) + if buildConsentFn(false) != nil { + t.Error("non-interactive run should return nil consent fn") + } + }) + t.Run("interactive returns collector", func(t *testing.T) { + resetFlags(t) + if buildConsentFn(true) == nil { + t.Error("interactive run should return a consent fn") + } + }) +} + +func issueResult(codes ...string) models.ScanPathResult { + var issues []models.Issue + for _, c := range codes { + issues = append(issues, models.Issue{Code: c}) + } + return models.ScanPathResult{Issues: issues} +} + +func TestCIExitError(t *testing.T) { + t.Run("ci off returns nil", func(t *testing.T) { + resetFlags(t) + if err := ciExitError([]models.ScanPathResult{issueResult("E001")}); err != nil { + t.Errorf("expected nil when CI off, got %v", err) + } + }) + + t.Run("ignored W001 does not fail, E-code does", func(t *testing.T) { + resetFlags(t) + commonFlags.CI = true + commonFlags.IgnoreIssueCodes = "W001" + results := []models.ScanPathResult{issueResult("W001", "E001")} + + err := ciExitError(results) + var ce *consentExitError + if !errors.As(err, &ce) || ce.code != 1 { + t.Fatalf("expected exit code 1, got %v", err) + } + // ciExitError must NOT mutate results; filtering is applyIgnoreCodes' job. + if len(results[0].Issues) != 2 { + t.Errorf("ciExitError must not mutate results, got %+v", results[0].Issues) + } + }) + + t.Run("only ignored issue → no failure", func(t *testing.T) { + resetFlags(t) + commonFlags.CI = true + commonFlags.IgnoreIssueCodes = "W001" + results := []models.ScanPathResult{issueResult("W001")} + if err := ciExitError(results); err != nil { + t.Errorf("expected nil (W001 ignored), got %v", err) + } + }) + + t.Run("runtime failure code fails CI", func(t *testing.T) { + resetFlags(t) + commonFlags.CI = true + results := []models.ScanPathResult{{ + Servers: []models.ServerScanResult{{ + Error: models.NewScanError("boom", models.ErrCatServerStartup, true), + }}, + }} + err := ciExitError(results) + var ce *consentExitError + if !errors.As(err, &ce) || ce.code != 1 { + t.Fatalf("expected exit code 1 for runtime failure, got %v", err) + } + }) + + t.Run("clean scan → nil", func(t *testing.T) { + resetFlags(t) + commonFlags.CI = true + if err := ciExitError([]models.ScanPathResult{issueResult()}); err != nil { + t.Errorf("expected nil for clean scan, got %v", err) + } + }) +} + +func TestApplyIgnoreCodes(t *testing.T) { + t.Run("removes ignored issues in CI mode", func(t *testing.T) { + resetFlags(t) + commonFlags.CI = true + commonFlags.IgnoreIssueCodes = "W001,W002" + results := []models.ScanPathResult{issueResult("W001", "E001", "W002")} + + applyIgnoreCodes(results) + if len(results[0].Issues) != 1 || results[0].Issues[0].Code != "E001" { + t.Errorf("expected only E001 to remain, got %+v", results[0].Issues) + } + }) + + t.Run("no-op outside CI mode", func(t *testing.T) { + resetFlags(t) + commonFlags.IgnoreIssueCodes = "W001" + results := []models.ScanPathResult{issueResult("W001")} + applyIgnoreCodes(results) + if len(results[0].Issues) != 1 { + t.Errorf("expected no filtering when CI off, got %+v", results[0].Issues) + } + }) +} diff --git a/internal/cli/flags.go b/internal/cli/flags.go index 8516a19..d430ff6 100644 --- a/internal/cli/flags.go +++ b/internal/cli/flags.go @@ -16,6 +16,9 @@ type CommonFlags struct { ScanAllUsers bool ServerTimeout int SuppressServerIO bool + DangerouslyRun bool + CI bool + IgnoreIssueCodes string } // ScanFlags holds scan-specific flags. @@ -60,6 +63,12 @@ func addCommonFlags(cmd *cobra.Command) { IntVar(&commonFlags.ServerTimeout, "server-timeout", 10, "MCP server connection timeout in seconds") cmd.Flags(). BoolVar(&commonFlags.SuppressServerIO, "suppress-mcpserver-io", true, "Suppress MCP server stdout/stderr") + cmd.Flags(). + BoolVar(&commonFlags.DangerouslyRun, "dangerously-run-mcp-servers", false, "Start every stdio MCP server without per-server consent prompts") + cmd.Flags(). + BoolVar(&commonFlags.CI, "ci", false, "Exit non-zero on findings or runtime failures (requires --dangerously-run-mcp-servers)") + cmd.Flags(). + StringVar(&commonFlags.IgnoreIssueCodes, "ignore-issues-codes", "", "Comma-separated issue codes to ignore for CI exit (only valid with --ci)") } // addScanFlags registers scan-specific flags. diff --git a/internal/cli/inspect.go b/internal/cli/inspect.go index 95e32f0..5f76b8a 100644 --- a/internal/cli/inspect.go +++ b/internal/cli/inspect.go @@ -1,10 +1,8 @@ package cli import ( - "context" "fmt" "os" - "time" "github.com/go-authgate/agent-scanner/internal/discovery" "github.com/go-authgate/agent-scanner/internal/inspect" @@ -29,11 +27,19 @@ func newInspectCmd() *cobra.Command { func runInspect(cmd *cobra.Command, args []string) error { setupLogging() + if err := validateConsentFlags(); err != nil { + exitOnConsentError(err) + return err + } + if !commonFlags.JSON { fmt.Fprintf(os.Stderr, "Agent Scanner v%s (inspect mode)\n\n", version.Version) } - ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Minute) + // inspect is always an interactive, manual invocation; skip the overall + // timeout when consent prompts are active so a slow answer can't cancel it. + consentFn := buildConsentFn(true) + ctx, cancel := scanContext(cmd.Context(), consentFn == nil) defer cancel() discoverer := discovery.NewDiscoverer() @@ -48,6 +54,7 @@ func runInspect(cmd *cobra.Command, args []string) error { ScanAllUsers: commonFlags.ScanAllUsers, InspectOnly: true, Verbose: commonFlags.Verbose, + ConsentFn: consentFn, }) results, err := p.Run(ctx) @@ -55,10 +62,20 @@ func runInspect(cmd *cobra.Command, args []string) error { return fmt.Errorf("inspect failed: %w", err) } + applyIgnoreCodes(results) + formatter := selectFormatter() - return formatter.FormatResults(results, output.FormatOptions{ + if err := formatter.FormatResults(results, output.FormatOptions{ PrintErrors: commonFlags.PrintErrors, PrintFullDescs: commonFlags.PrintFullDescs, InspectMode: true, - }) + }); err != nil { + return err + } + + if err := ciExitError(results); err != nil { + exitOnConsentError(err) + return err + } + return nil } diff --git a/internal/cli/scan.go b/internal/cli/scan.go index 996abba..9990bb2 100644 --- a/internal/cli/scan.go +++ b/internal/cli/scan.go @@ -1,12 +1,10 @@ package cli import ( - "context" "fmt" "log/slog" "os" "strings" - "time" "github.com/go-authgate/agent-scanner/internal/analysis" "github.com/go-authgate/agent-scanner/internal/discovery" @@ -34,9 +32,20 @@ func newScanCmd() *cobra.Command { func runScan(cmd *cobra.Command, args []string) error { setupLogging() + + // Validate consent/CI flag relationships before doing any work. + if err := validateConsentFlags(); err != nil { + exitOnConsentError(err) + return err + } + printBanner() - ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Minute) + // Interactive consent prompts can take an arbitrary amount of human time, so + // only impose the overall timeout on non-interactive runs (where prompts are + // skipped). Per-server connection timeouts still bound each server. + consentFn := buildConsentFn(!hasPushKey()) + ctx, cancel := scanContext(cmd.Context(), consentFn == nil) defer cancel() // Build pipeline components @@ -64,6 +73,9 @@ func runScan(cmd *cobra.Command, args []string) error { SkipSSLVerify: commonFlags.SkipSSLVerify, ChecksPerServer: scanFlags.ChecksPerServer, Verbose: commonFlags.Verbose, + // Interactive consent gates stdio startup unless a push key marks an + // automated run; --dangerously-run-mcp-servers skips prompts entirely. + ConsentFn: consentFn, }) results, err := p.Run(ctx) @@ -71,13 +83,26 @@ func runScan(cmd *cobra.Command, args []string) error { return fmt.Errorf("scan failed: %w", err) } + // In CI mode, drop ignored issue findings before formatting so they are not + // printed. (Runtime-failure X-codes are not filtered here.) + applyIgnoreCodes(results) + // Format output formatter := selectFormatter() - return formatter.FormatResults(results, output.FormatOptions{ + if err := formatter.FormatResults(results, output.FormatOptions{ PrintErrors: commonFlags.PrintErrors, PrintFullDescs: commonFlags.PrintFullDescs, InspectMode: false, - }) + }); err != nil { + return err + } + + // In CI mode, exit non-zero when findings or runtime failures remain. + if err := ciExitError(results); err != nil { + exitOnConsentError(err) + return err + } + return nil } func printBanner() { diff --git a/internal/consent/consent.go b/internal/consent/consent.go new file mode 100644 index 0000000..57d1674 --- /dev/null +++ b/internal/consent/consent.go @@ -0,0 +1,256 @@ +// Package consent implements interactive per-server consent for starting stdio +// MCP servers. Before any subprocess is launched, the user is shown the command +// (and redacted env) for each stdio server and asked to allow or decline it. +// +// The consent UI is diagnostic chrome, not scan output, so callers render it on +// stderr and read answers from stdin. +package consent + +import ( + "bufio" + "fmt" + "io" + "sort" + "strings" + "unicode" + + "github.com/go-authgate/agent-scanner/internal/models" +) + +// stdioItem and remoteItem are the enumerated servers across all clients. +type stdioItem struct { + configPath string + name string + server *models.StdioServer +} + +type remoteItem struct { + configPath string + name string + server *models.RemoteServer +} + +// CollectConsent prompts the user per stdio MCP server before any subprocess is +// started and returns the set of declined servers (keyed by ServerRef). Remote +// servers start no subprocess and are auto-allowed. Prompts are written to out +// and answers read from in; an empty line or EOF is treated as "decline". +func CollectConsent( + clients []*models.ClientToInspect, + out io.Writer, + in io.Reader, +) map[models.ServerRef]bool { + stdioItems, remoteItems := enumerate(clients) + + declined := map[models.ServerRef]bool{} + if len(stdioItems) == 0 && len(remoteItems) == 0 { + return declined + } + + if len(stdioItems) > 0 { + fmt.Fprint(out, + "Agent Scanner will launch stdio MCP servers as subprocesses to inspect their tools.\n"+ + "Review each command below and confirm whether Agent Scanner may start it.\n"+ + "Tip: pass --dangerously-run-mcp-servers to start every server without prompting.\n\n") + + reader := bufio.NewReader(in) + fmt.Fprint(out, "Stdio MCP servers (require consent):\n") + for idx, item := range stdioItems { + fmt.Fprintf(out, "\n [%d] %s\n", idx+1, sanitize(item.name)) + fmt.Fprintf(out, " config : %s\n", sanitize(item.configPath)) + fmt.Fprintf(out, " command: %s\n", sanitize(renderCommand(item.server))) + if env := renderEnvRedacted(item.server); env != "" { + fmt.Fprintf(out, " env : %s\n", sanitize(env)) + } + fmt.Fprintf( + out, + " Allow Agent Scanner to start '%s'? [y/N]: ", + sanitize(item.name), + ) + if readYesNo(reader) { + fmt.Fprintf(out, " Allowed: '%s' will be started.\n", sanitize(item.name)) + } else { + declined[models.ServerRef{ConfigPath: item.configPath, Name: item.name}] = true + fmt.Fprintf(out, " Declined: '%s' will not be started.\n", sanitize(item.name)) + } + } + } + + if len(remoteItems) > 0 { + fmt.Fprint(out, "\nRemote MCP servers (no subprocess — auto-allowed):\n") + for _, item := range remoteItems { + typeStr := string(item.server.GetServerType()) + fmt.Fprintf( + out, + " - %s (%s, %s) %s\n", + sanitize(item.name), + sanitize(typeStr), + sanitize(item.server.URL), + sanitize(item.configPath), + ) + } + } + + // The "proceeding" summary only makes sense when stdio servers were prompted. + if len(stdioItems) > 0 { + allowed := len(stdioItems) - len(declined) + msg := fmt.Sprintf("\nProceeding with %d of %d stdio servers.", allowed, len(stdioItems)) + if len(declined) > 0 { + msg += fmt.Sprintf(" Skipped: %d.", len(declined)) + } + fmt.Fprintln(out, msg) + if len(declined) > 0 { + fmt.Fprint(out, + "Note: declined servers will not be started on this machine. Analysis "+ + "results may still appear for them if recognized from prior scans — "+ + "those are not based on your machine's behavior.\n") + } + } + return declined +} + +// enumerate collects stdio and remote servers across all clients, skipping +// config entries that failed to parse. +func enumerate(clients []*models.ClientToInspect) ([]stdioItem, []remoteItem) { + var stdioItems []stdioItem + var remoteItems []remoteItem + // Prompt for each (configPath, name) at most once, so duplicate scan paths + // or multiple clients referencing the same config don't double-prompt or + // make the decision order-dependent. First occurrence wins, and clients are + // sorted so that "first" is stable across runs regardless of discovery order. + sorted := make([]*models.ClientToInspect, len(clients)) + copy(sorted, clients) + sort.Slice(sorted, func(i, j int) bool { + if sorted[i].ClientPath != sorted[j].ClientPath { + return sorted[i].ClientPath < sorted[j].ClientPath + } + return sorted[i].Name < sorted[j].Name + }) + seen := map[models.ServerRef]bool{} + for _, client := range sorted { + // Deterministic ordering: sort config paths, then server names. + paths := make([]string, 0, len(client.MCPConfigs)) + for path := range client.MCPConfigs { + paths = append(paths, path) + } + sort.Strings(paths) + for _, configPath := range paths { + configOrErr := client.MCPConfigs[configPath] + if configOrErr.Error != nil || configOrErr.Config == nil { + continue + } + servers := configOrErr.Config.GetServers() + names := make([]string, 0, len(servers)) + for name := range servers { + names = append(names, name) + } + sort.Strings(names) + for _, name := range names { + ref := models.ServerRef{ConfigPath: configPath, Name: name} + if seen[ref] { + continue + } + seen[ref] = true + switch s := servers[name].(type) { + case *models.StdioServer: + stdioItems = append(stdioItems, stdioItem{configPath, name, s}) + case *models.RemoteServer: + remoteItems = append(remoteItems, remoteItem{configPath, name, s}) + } + } + } + } + return stdioItems, remoteItems +} + +// renderCommand renders the launch command with shell-quoted arguments. +func renderCommand(server *models.StdioServer) string { + parts := make([]string, 0, 1+len(server.Args)) + parts = append(parts, shellQuote(server.Command)) + for _, arg := range server.Args { + parts = append(parts, shellQuote(arg)) + } + return strings.Join(parts, " ") +} + +// renderEnvRedacted renders env keys as "KEY=***". Values are never echoed. +func renderEnvRedacted(server *models.StdioServer) string { + if len(server.Env) == 0 { + return "" + } + keys := make([]string, 0, len(server.Env)) + for k := range server.Env { + keys = append(keys, k) + } + sort.Strings(keys) + for i, k := range keys { + keys[i] = k + "=***" + } + return strings.Join(keys, ", ") +} + +// readYesNo reads one line and returns true only for "y"/"yes" (case +// insensitive). Empty line and EOF both mean "no". +func readYesNo(reader *bufio.Reader) bool { + line, err := reader.ReadString('\n') + if err != nil && line == "" { + return false // EOF with no input + } + switch strings.ToLower(strings.TrimSpace(line)) { + case "y", "yes": + return true + default: + return false + } +} + +// sanitize replaces control characters — including newlines, tabs, and ANSI +// escape sequences — with a visible escaped form. Server names, commands, and +// env keys come from untrusted local config files; without this a malicious +// config could embed terminal control codes to spoof or rewrite the consent +// prompt (e.g. hide the real command or fake an "Allowed" line). +func sanitize(s string) string { + var b strings.Builder + for _, r := range s { + if unicode.IsControl(r) { + switch { + case r <= 0xff: + fmt.Fprintf(&b, `\x%02x`, r) + case r <= 0xffff: + fmt.Fprintf(&b, `\u%04x`, r) + default: + fmt.Fprintf(&b, `\U%08x`, r) + } + continue + } + b.WriteRune(r) + } + return b.String() +} + +// shellSafe reports whether r can appear unquoted in a shell command. +func shellSafe(r rune) bool { + switch { + case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9': + return true + } + switch r { + case '_', '-', '.', '/', '=', ':', '@', '+', ',': + return true + } + return false +} + +// shellQuote single-quotes s if it contains characters that are unsafe to leave +// bare, using POSIX single-quote escaping. The rendered command is shown for +// human review in the consent prompt and is never executed via a shell, so +// POSIX quoting is used purely for legibility (it is not Windows-shell correct). +func shellQuote(s string) string { + if s == "" { + return "''" + } + if strings.IndexFunc(s, func(r rune) bool { return !shellSafe(r) }) == -1 { + return s + } + // Escape embedded single quotes: ' -> '\'' + return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'" +} diff --git a/internal/consent/consent_test.go b/internal/consent/consent_test.go new file mode 100644 index 0000000..cc95bce --- /dev/null +++ b/internal/consent/consent_test.go @@ -0,0 +1,199 @@ +package consent + +import ( + "bytes" + "strings" + "testing" + + "github.com/go-authgate/agent-scanner/internal/models" +) + +// clientWith builds a single client whose config at "cfg.json" holds the given +// servers, using a real ClaudeConfigFile so GetServers() behaves normally. +func clientWith(servers map[string]models.ServerConfigJSON) []*models.ClientToInspect { + return []*models.ClientToInspect{{ + Name: "test", + MCPConfigs: map[string]models.MCPConfigOrError{ + "cfg.json": {Config: &models.ClaudeConfigFile{MCPServers: servers}}, + }, + }} +} + +func TestCollectConsent_AllowAndDecline(t *testing.T) { + clients := clientWith(map[string]models.ServerConfigJSON{ + "alpha": {Command: "npx", Args: []string{"-y", "alpha"}}, + "beta": {Command: "uvx", Args: []string{"beta"}}, + }) + // Servers are prompted in sorted order: alpha, then beta. Allow alpha, decline beta. + var out bytes.Buffer + declined := CollectConsent(clients, &out, strings.NewReader("y\nn\n")) + + if declined[models.ServerRef{ConfigPath: "cfg.json", Name: "alpha"}] { + t.Error("alpha should be allowed") + } + if !declined[models.ServerRef{ConfigPath: "cfg.json", Name: "beta"}] { + t.Error("beta should be declined") + } + if got := out.String(); !strings.Contains(got, "Proceeding with 1 of 2 stdio servers") { + t.Errorf("missing summary line, got:\n%s", got) + } +} + +func TestCollectConsent_EmptyInputDeclinesAll(t *testing.T) { + clients := clientWith(map[string]models.ServerConfigJSON{ + "alpha": {Command: "npx"}, + }) + var out bytes.Buffer + // EOF immediately → decline. + declined := CollectConsent(clients, &out, strings.NewReader("")) + if !declined[models.ServerRef{ConfigPath: "cfg.json", Name: "alpha"}] { + t.Error("expected alpha declined on EOF") + } +} + +func TestCollectConsent_YesVariants(t *testing.T) { + for _, answer := range []string{"y\n", "Y\n", "yes\n", "YES\n", " yes \n"} { + clients := clientWith(map[string]models.ServerConfigJSON{"alpha": {Command: "npx"}}) + var out bytes.Buffer + declined := CollectConsent(clients, &out, strings.NewReader(answer)) + if len(declined) != 0 { + t.Errorf("answer %q: expected allow (no declines), got %v", answer, declined) + } + } +} + +func TestCollectConsent_EnvRedacted(t *testing.T) { + clients := clientWith(map[string]models.ServerConfigJSON{ + "alpha": {Command: "npx", Env: map[string]string{"API_KEY": "super-secret-value"}}, + }) + var out bytes.Buffer + CollectConsent(clients, &out, strings.NewReader("n\n")) + got := out.String() + if strings.Contains(got, "super-secret-value") { + t.Error("env value must never be echoed") + } + if !strings.Contains(got, "API_KEY=***") { + t.Errorf("expected redacted env, got:\n%s", got) + } +} + +func TestCollectConsent_RemoteAutoAllowed(t *testing.T) { + clients := clientWith(map[string]models.ServerConfigJSON{ + "remote": {URL: "https://example.com/mcp"}, + }) + var out bytes.Buffer + declined := CollectConsent(clients, &out, strings.NewReader("")) + if len(declined) != 0 { + t.Errorf("remote servers are auto-allowed, got declines: %v", declined) + } + got := out.String() + if !strings.Contains(got, "auto-allowed") || !strings.Contains(got, "https://example.com/mcp") { + t.Errorf("expected remote listing, got:\n%s", got) + } + // Remote-only runs must not show the stdio launch banner or the "0 of 0" summary. + if strings.Contains(got, "launch stdio MCP servers") { + t.Errorf("remote-only run should not show stdio banner, got:\n%s", got) + } + if strings.Contains(got, "Proceeding with") { + t.Errorf("remote-only run should not show stdio proceeding summary, got:\n%s", got) + } +} + +func TestCollectConsent_NoServers(t *testing.T) { + var out bytes.Buffer + declined := CollectConsent(nil, &out, strings.NewReader("")) + if len(declined) != 0 { + t.Errorf("expected empty declined, got %v", declined) + } + if out.Len() != 0 { + t.Errorf("expected no output for empty input, got:\n%s", out.String()) + } +} + +func TestSanitize(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {"plain", "npx server", "npx server"}, + {"newline", "evil\nAllowed", `evil\x0aAllowed`}, + {"carriage return", "a\rb", `a\x0db`}, + {"tab", "a\tb", `a\x09b`}, + {"ansi escape", "\x1b[31mred\x1b[0m", `\x1b[31mred\x1b[0m`}, + {"keeps unicode letters", "café", "café"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := sanitize(tt.in); got != tt.want { + t.Errorf("sanitize(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestCollectConsent_SanitizesControlChars(t *testing.T) { + // A malicious server name with embedded newline + ANSI must not reach the + // terminal raw; it should appear escaped in the prompt. + clients := clientWith(map[string]models.ServerConfigJSON{ + "evil\n\x1b[2KAllowed: fake": {Command: "x"}, + }) + var out bytes.Buffer + CollectConsent(clients, &out, strings.NewReader("n\n")) + got := out.String() + if strings.Contains(got, "\x1b") || strings.Contains(got, "\n\x1b") { + t.Errorf("raw control characters leaked into prompt output:\n%q", got) + } + if !strings.Contains(got, `\x1b`) { + t.Errorf("expected escaped ANSI in output, got:\n%q", got) + } +} + +func TestCollectConsent_DeduplicatesByServerRef(t *testing.T) { + // Two clients reference the same config path + server name. The user must be + // prompted only once for that server. + mk := func() *models.ClientToInspect { + return &models.ClientToInspect{ + Name: "test", + MCPConfigs: map[string]models.MCPConfigOrError{ + "cfg.json": {Config: &models.ClaudeConfigFile{ + MCPServers: map[string]models.ServerConfigJSON{"dup": {Command: "npx"}}, + }}, + }, + } + } + clients := []*models.ClientToInspect{mk(), mk()} + + var out bytes.Buffer + declined := CollectConsent(clients, &out, strings.NewReader("n\n")) + + if got := strings.Count(out.String(), "Allow Agent Scanner to start 'dup'?"); got != 1 { + t.Errorf("expected exactly 1 prompt for duplicate server, got %d", got) + } + if !declined[models.ServerRef{ConfigPath: "cfg.json", Name: "dup"}] { + t.Error("expected 'dup' to be declined") + } + if !strings.Contains(out.String(), "Proceeding with 0 of 1 stdio servers") { + t.Errorf("expected a single deduplicated stdio server, got:\n%s", out.String()) + } +} + +func TestRenderCommand_ShellQuoting(t *testing.T) { + tests := []struct { + name string + s *models.StdioServer + want string + }{ + {"plain", &models.StdioServer{Command: "npx", Args: []string{"-y", "pkg"}}, "npx -y pkg"}, + {"space arg", &models.StdioServer{Command: "cmd", Args: []string{"a b"}}, "cmd 'a b'"}, + {"quote arg", &models.StdioServer{Command: "cmd", Args: []string{"it's"}}, `cmd 'it'\''s'`}, + {"empty arg", &models.StdioServer{Command: "cmd", Args: []string{""}}, "cmd ''"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := renderCommand(tt.s); got != tt.want { + t.Errorf("renderCommand = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/inspect/inspector.go b/internal/inspect/inspector.go index 8984929..66defc5 100644 --- a/internal/inspect/inspector.go +++ b/internal/inspect/inspector.go @@ -181,6 +181,13 @@ func (i *inspector) InspectClient( servers := configOrErr.Config.GetServers() for name, serverCfg := range servers { + // Skip stdio servers the user declined during consent. Remote + // servers start no subprocess and are always inspected. + if serverCfg.GetServerType() == models.ServerTypeStdio && + client.Declined[models.ServerRef{ConfigPath: configPath, Name: name}] { + slog.Debug("skipping declined server", "name", name, "config", configPath) + continue + } name, serverCfg, configPath := name, serverCfg, configPath g.Go(func() error { ext, _ := i.InspectServer(ctx, name, serverCfg) diff --git a/internal/inspect/inspector_test.go b/internal/inspect/inspector_test.go new file mode 100644 index 0000000..3802f9d --- /dev/null +++ b/internal/inspect/inspector_test.go @@ -0,0 +1,105 @@ +package inspect + +import ( + "context" + "errors" + "sort" + "sync" + "testing" + + "github.com/go-authgate/agent-scanner/internal/mcpclient" + "github.com/go-authgate/agent-scanner/internal/models" +) + +// recordingClient records the configs it is asked to connect to and always +// fails the connection, so InspectServer returns an error extension without +// needing a full live session. +type recordingClient struct { + mu sync.Mutex + connected []models.ServerConfig +} + +func (c *recordingClient) Connect( + _ context.Context, + cfg models.ServerConfig, + _ int, +) (mcpclient.Session, error) { + c.mu.Lock() + c.connected = append(c.connected, cfg) + c.mu.Unlock() + return nil, errors.New("connect disabled in test") +} + +func (c *recordingClient) connectedCommands() []string { + c.mu.Lock() + defer c.mu.Unlock() + var cmds []string + for _, cfg := range c.connected { + if s, ok := cfg.(*models.StdioServer); ok { + cmds = append(cmds, s.Command) + } + } + sort.Strings(cmds) + return cmds +} + +func TestInspectClient_SkipsDeclinedStdioServers(t *testing.T) { + rc := &recordingClient{} + insp := NewInspector(rc, 1) + + client := &models.ClientToInspect{ + Name: "test", + MCPConfigs: map[string]models.MCPConfigOrError{ + "cfg.json": {Config: &models.ClaudeConfigFile{ + MCPServers: map[string]models.ServerConfigJSON{ + "allowed": {Command: "allowed-cmd"}, + "declined": {Command: "declined-cmd"}, + }, + }}, + }, + Declined: map[models.ServerRef]bool{ + {ConfigPath: "cfg.json", Name: "declined"}: true, + }, + } + + if _, err := insp.InspectClient(context.Background(), client, false); err != nil { + t.Fatalf("InspectClient: %v", err) + } + + got := rc.connectedCommands() + if len(got) != 1 || got[0] != "allowed-cmd" { + t.Errorf("expected only 'allowed-cmd' to be connected, got %v", got) + } +} + +func TestInspectClient_DeclinedRemoteStillInspected(t *testing.T) { + rc := &recordingClient{} + insp := NewInspector(rc, 1) + + // A declined entry that happens to match a remote server name must NOT skip + // it — only stdio servers are subject to consent. + client := &models.ClientToInspect{ + Name: "test", + MCPConfigs: map[string]models.MCPConfigOrError{ + "cfg.json": {Config: &models.ClaudeConfigFile{ + MCPServers: map[string]models.ServerConfigJSON{ + "remote": {URL: "https://example.com/mcp"}, + }, + }}, + }, + Declined: map[models.ServerRef]bool{ + {ConfigPath: "cfg.json", Name: "remote"}: true, + }, + } + + if _, err := insp.InspectClient(context.Background(), client, false); err != nil { + t.Fatalf("InspectClient: %v", err) + } + + rc.mu.Lock() + n := len(rc.connected) + rc.mu.Unlock() + if n != 1 { + t.Errorf("remote server should still be inspected, got %d connect calls", n) + } +} diff --git a/internal/models/client.go b/internal/models/client.go index cbfb63d..12babed 100644 --- a/internal/models/client.go +++ b/internal/models/client.go @@ -8,12 +8,23 @@ type CandidateClient struct { SkillsDirPaths []string `json:"skills_dir_paths,omitempty"` } +// ServerRef uniquely identifies a server within a scan by its config file path +// and server name. It is used as the key for per-server consent decisions. +type ServerRef struct { + ConfigPath string + Name string +} + // ClientToInspect represents a client with resolved config paths ready for inspection. type ClientToInspect struct { Name string `json:"name"` ClientPath string `json:"client_path"` MCPConfigs map[string]MCPConfigOrError `json:"mcp_configs"` SkillsDirs map[string][]SkillEntry `json:"skills_dirs,omitempty"` + // Declined holds stdio servers the user declined to start during interactive + // consent. It is runtime state (not part of discovery output) and is read + // during inspection to skip those servers. A nil/empty map means run all. + Declined map[ServerRef]bool `json:"-"` } // MCPConfigOrError holds either a parsed config or an error. diff --git a/internal/pipeline/pipeline.go b/internal/pipeline/pipeline.go index d03ec9b..ea05f1c 100644 --- a/internal/pipeline/pipeline.go +++ b/internal/pipeline/pipeline.go @@ -34,6 +34,10 @@ type Config struct { ChecksPerServer int InspectOnly bool Verbose bool + // ConsentFn, if set, is called after discovery with all clients to inspect. + // It returns the set of stdio servers the user declined; those servers are + // not started. A nil ConsentFn means run every server without gating. + ConsentFn func([]*models.ClientToInspect) map[models.ServerRef]bool } // Pipeline orchestrates the scan process. @@ -90,6 +94,16 @@ func (p *Pipeline) inspect(ctx context.Context) []models.ScanPathResult { } } + // Collect consent before any stdio subprocess is started. The returned + // declined set is attached to each client so inspection can skip them. + if p.config.ConsentFn != nil { + if declined := p.config.ConsentFn(allClients); len(declined) > 0 { + for _, client := range allClients { + client.Declined = declined + } + } + } + var results []models.ScanPathResult for _, client := range allClients {