diff --git a/.dockerignore b/.dockerignore index bab40b1..db86c12 100644 --- a/.dockerignore +++ b/.dockerignore @@ -11,6 +11,14 @@ atryum.toml docker-compose.override.yml internal/api/web releases +integrations/.venv +integrations/.run +integrations/.harness-config +integrations/results +integrations/*.db +integrations/*.db-journal +integrations/*.log +integrations/*.pid opencode-state node_modules ui/node_modules diff --git a/Dockerfile.integrations b/Dockerfile.integrations new file mode 100644 index 0000000..9821092 --- /dev/null +++ b/Dockerfile.integrations @@ -0,0 +1,41 @@ +FROM golang:1.25-bookworm + +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +ENV DEBIAN_FRONTEND=noninteractive +ENV PATH=/usr/local/go/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin + +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + git \ + python3 \ + python3-pip \ + python3-venv \ + && rm -rf /var/lib/apt/lists/* + +RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \ + && apt-get update \ + && apt-get install -y --no-install-recommends nodejs \ + && rm -rf /var/lib/apt/lists/* + +RUN npm install -g @openai/codex @anthropic-ai/claude-code + +WORKDIR /src + +COPY ui/package.json ui/package-lock.json ./ui/ +RUN cd ui && npm ci + +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . + +RUN cd ui && npm run build \ + && rm -rf ../internal/api/web \ + && mkdir -p ../internal/api/web \ + && cp -R dist/. ../internal/api/web/ + +ENTRYPOINT ["bash", "/src/integrations/scripts/agent_harness_integration_tests.sh"] +CMD ["matrix", "--only-passing", "--harnesses", "fake-agent", "--targets", "calculator"] diff --git a/Justfile b/Justfile index c8148a4..adc5819 100644 --- a/Justfile +++ b/Justfile @@ -2,6 +2,7 @@ set shell := ["bash", "-cu"] config := "./atryum.toml" release_dir := "releases" +integration_image := "atryum-integrations" # List justfile targets default: @@ -39,9 +40,11 @@ check: fmt test build: CGO_ENABLED=0 go build -o ./atryum ./cmd/atryum -# Remove generated binaries, release artifacts, and built UI assets +# Remove generated binaries, release artifacts, built UI assets, and integration test debris clean: - rm -rf ./atryum {{release_dir}} ui/dist internal/api/web + rm -rf ./atryum {{release_dir}} ui/dist internal/api/web \ + integrations/.venv integrations/.run integrations/.harness-config integrations/results \ + integrations/*.db integrations/*.db-journal integrations/*.log integrations/*.pid # Build local production-like atryum binary with the local UI embedded build-prod: build-ui build @@ -163,3 +166,17 @@ integration-test harness="fake-agent" auth="no-auth" target="calculator": # Run the full integration matrix (skips unavailable harnesses and placeholder auth) integration-test-matrix *args: integrations/scripts/agent_harness_integration_tests.sh matrix --only-passing {{args}} + +# Build Docker image for integration tests +integration-docker-build: + docker build -f Dockerfile.integrations -t {{integration_image}} . + +# Run integration tests inside the Docker image +integration-docker-test *args: + docker run --rm \ + -e OPENAI_API_KEY \ + -e CODEX_API_KEY \ + -e ANTHROPIC_API_KEY \ + -e AMP_API_KEY \ + -e XAI_API_KEY \ + {{integration_image}} {{args}} diff --git a/README.md b/README.md index 16e171c..15f9a57 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Atryum mediates three kinds of tool calls: - **Pre-tool hooks from agent harnesses.** Managed harnesses (Claude Code, Cursor, amp, Pi) and autonomous ones (Microsoft Foundry, custom orchestrators) post their intended tool call to `POST /api/v1/external/invocations` (when the harness executes the tool itself) or `POST /api/v1/invocations` (when Atryum should execute it). The harness blocks on the response and only proceeds if Atryum returns an approved status. In the hook path Atryum never touches the tool — it just answers "may this call happen." - **Direct MCP proxying.** Agents that speak MCP connect to `POST /mcp/{server}` as their MCP endpoint. Atryum implements the JSON-RPC surface (`initialize`, `notifications/initialized`, `tools/list`, `tools/call`) and proxies calls to the configured upstream — HTTP or stdio. Because Atryum is the MCP client to the upstream, it holds the credentials (OAuth tokens, bearer tokens, custom headers) and the agent never sees them. The same approval engine runs on every `tools/call`. -- **Claude Managed Agents events bridge.** Anthropic's hosted harness runs the agent loop on its own infrastructure and never calls Atryum, so for those sessions Atryum dials *out*: it streams a registered session's [events](https://platform.claude.com/docs/en/managed-agents/events-and-streaming), records the raw session events on a synthetic audit invocation, and — when the session blocks on a tool call (`session.status_idle` / `requires_action`) — runs the normal approval rules and answers Claude with a `user.tool_confirmation` (or `user.custom_tool_result`). Each tool call is also recorded as its own invocation. This gates both built-in and MCP tools. Enable it by declaring one or more `[[managed_agents]]` accounts (each with a `name` and `api_key`) and register sessions via `POST /api/v1/admin/managed-agents/sessions` (set `"account"` to target a specific account when more than one is configured). See `examples/managed-agents/`. +- **Claude Managed Agents events bridge.** Anthropic's hosted harness runs the agent loop on its own infrastructure and never calls Atryum, so for those sessions Atryum dials *out*: it discovers linked Claude sessions, streams their [events](https://platform.claude.com/docs/en/managed-agents/events-and-streaming), records the raw session events on a synthetic audit invocation, and — when the session blocks on a tool call (`session.status_idle` / `requires_action`) — runs the normal approval rules and answers Claude with a `user.tool_confirmation` (or `user.custom_tool_result`). Each tool call is also recorded as its own invocation. This gates both built-in and MCP tools. Enable it by declaring one or more `[[managed_agents]]` accounts (each with a `name`, `workspace`, and `api_key`) and link Claude agents from the Agents UI. Manual session registration remains available at `POST /api/v1/admin/managed-agents/sessions`. See `examples/managed-agents/`. These paths converge on a single service so rules, audit, and the UI work identically regardless of how the call arrived. @@ -73,7 +73,8 @@ Admin (UI and operators): - `/api/v1/admin/rules`, `/{id}` (including reorder/move) - `/api/v1/admin/agents`, `/{id}` - `/api/v1/admin/settings`, `/api/v1/admin/policy` -- `/api/v1/admin/managed-agents/sessions` — register a Claude Managed Agents session for the events bridge to watch (body may include `"account"` to choose which `[[managed_agents]]` entry; returns `501` when no `[[managed_agents]]` account is configured) +- `/api/v1/admin/managed-agents/accounts`, `/managed-agents/agents` — discover configured Anthropic accounts and Claude agents for UI linking +- `/api/v1/admin/managed-agents/sessions` — manually register a Claude Managed Agents session for the events bridge to watch; kept as a debugging escape hatch - `/api/v1/admin/oauth/callback` — OAuth callback for upstream MCP server connect flows ## Frontend @@ -92,6 +93,7 @@ SQLite by default, PostgreSQL optional via `server.database_url`. Both are first - `invocation_events` — append-only lifecycle events. - `approval_rules` — the rule engine. - `agents` — local agent records and their authenticated-ID mappings. +- `managed_agent_bindings` — local Atryum agent to Claude Managed Agent links used for session discovery. - `managed_agent_sessions` — Claude Managed Agents sessions watched by the events bridge, with each one's event cursor for resume-after-restart. ## Config @@ -102,6 +104,7 @@ A single TOML file configures process and bootstrap settings; runtime entities ( [server] listen_addr = ":8080" public_base_url = "http://localhost:8080" # browser-facing API URL for OAuth callbacks +atryum_instance = "" # stable metadata identity; defaults to public_base_url database_path = "./atryum.db" # or set database_url for Postgres database_url = "" # postgres://, postgresql://, sqlite://, file:, or a SQLite path log_level = "info" @@ -115,10 +118,11 @@ machine_key = "" machine_secret = "" connection_timeout_seconds = 5 -[[managed_agents]] # optional, repeatable — one per Anthropic account +[[managed_agents]] # optional, repeatable — one per Anthropic account/workspace key name = "default" # unique label; the session-registration "account" targets it -api_key = "" # Anthropic API key; empty entries are skipped - # env override (single account only): ATRYUM_MANAGED_AGENTS_API_KEY, then ANTHROPIC_API_KEY +workspace = "" # required label when api_key is set; used for display/metadata +api_key = "" # Anthropic API key created in that workspace; empty entries are skipped + # env override (single account only): ATRYUM_MANAGED_AGENTS_API_KEY, then ANTHROPIC_API_KEY; workspace label via ATRYUM_MANAGED_AGENTS_WORKSPACE [[auth]] # optional — repeatable per authorization server issuer = "https://keycloak.example/realms/agents" diff --git a/atryum.example.toml b/atryum.example.toml index 9557c69..499c7ce 100644 --- a/atryum.example.toml +++ b/atryum.example.toml @@ -4,6 +4,9 @@ listen_addr = ":8080" # In local dev this should point at the Atryum API server, not the Vite UI. # In production/Kubernetes it should be the ingress URL. public_base_url = "http://localhost:8080" +# Stable identity written into Claude Managed Agent metadata when Atryum links +# to an Anthropic agent. Defaults to public_base_url when omitted. +atryum_instance = "" # SQLite remains the default. database_path is used when database_url is empty. database_path = "./atryum.db" # Optional: select storage provider by URL scheme. @@ -57,13 +60,18 @@ request_timeout_seconds = 30 # api_key are skipped; with no usable entry the bridge is disabled. See # examples/managed-agents/. # -# Environment overrides for api_key (single-account convenience only): +# Environment overrides for api_key/workspace label (single-account convenience only): # ATRYUM_MANAGED_AGENTS_API_KEY (highest), then ANTHROPIC_API_KEY. These apply # only when zero or one [[managed_agents]] entry is configured. With multiple -# entries each must set its own api_key in TOML. +# entries each must set its own api_key in TOML. Set +# ATRYUM_MANAGED_AGENTS_WORKSPACE when using the env-only API key path. The API +# key itself must be created in the target Anthropic workspace; workspace is a +# label used by Atryum for display and ownership metadata, not an Anthropic +# request selector. [[managed_agents]] name = "default" # unique label; targeted by the session-registration "account" field -api_key = "" +workspace = "" # required label when api_key is set; Anthropic workspace identifier/name +api_key = "" # Anthropic API key created in that workspace # Optional tuning (defaults shown): # base_url = "https://api.anthropic.com" # poll_interval_millis = 1000 @@ -74,6 +82,7 @@ api_key = "" # Add more accounts by repeating the table: # [[managed_agents]] # name = "staging" +# workspace = "staging-workspace" # api_key = "" # Agent sync and AI evaluation settings (org, record type, charter field) diff --git a/cmd/atryum/main.go b/cmd/atryum/main.go index b20b3bf..564c5f2 100644 --- a/cmd/atryum/main.go +++ b/cmd/atryum/main.go @@ -112,6 +112,7 @@ func runServer(args []string) error { oauthRepo := store.NewOAuthRepoWithDialect(db, dialect) rulesRepo := store.NewRulesRepoWithDialect(db, dialect) agentsRepo := store.NewAgentsRepoWithDialect(db, dialect) + managedAgentBindingRepo := store.NewManagedAgentBindingRepoWithDialect(db, dialect) agentSyncSettingsRepo := store.NewAgentSyncSettingsRepoWithDialect(db, dialect) llmConfigsRepo := store.NewLLMConfigsRepoWithDialect(db, dialect) @@ -186,7 +187,7 @@ func runServer(args []string) error { // invocation package interfaces without creating import cycles. var invAgents invocation.AgentLookup if agentsRepo != nil { - invAgents = &agentsLookupAdapter{repo: agentsRepo} + invAgents = &agentsLookupAdapter{repo: agentsRepo, managedBindings: managedAgentBindingRepo} } // Build the evaluator: always create a local evaluator backed by llmConfigsRepo. @@ -219,6 +220,7 @@ func runServer(args []string) error { syncAgentsFn = syncAgents } handler := api.NewHandler(service, serverAdmin, policyRegistry, rulesRepo, agentsRepo, agentSyncSettingsRepo, llmConfigsRepo, syncAgentsFn, backendClient, localEvaluator) + handler.SetManagedAgentBindings(managedAgentBindingRepo) authValidator, err := auth.NewValidator(cfg.Auth, nil) if err != nil { @@ -252,8 +254,12 @@ func runServer(args []string) error { if ma.APIKey == "" { continue } + if ma.Workspace == "" { + return fmt.Errorf("managed_agents workspace label is required for account %q; use an Anthropic API key created in that workspace", emptyDefault(ma.Name, managedagents.DefaultAccountName)) + } acctCfg := managedagents.Config{ Name: ma.Name, + Workspace: ma.Workspace, BaseURL: ma.BaseURL, APIKey: ma.APIKey, PollInterval: time.Duration(ma.PollIntervalMillis) * time.Millisecond, @@ -277,6 +283,12 @@ func runServer(args []string) error { if err != nil { return fmt.Errorf("configure managed agents bridge: %w", err) } + instanceName := cfg.Server.AtryumInstance + if instanceName == "" { + instanceName = cfg.Server.PublicBaseURL + } + managedSvc.SetInstanceName(instanceName) + managedSvc.SetBindings(&managedBindingStoreAdapter{repo: managedAgentBindingRepo}) if err := managedSvc.Start(context.Background()); err != nil { return fmt.Errorf("start managed agents bridge: %w", err) } @@ -351,6 +363,10 @@ func (a *managedSessionStoreAdapter) List(ctx context.Context) ([]managedagents. return out, nil } +func (a *managedSessionStoreAdapter) Delete(ctx context.Context, sessionID string) error { + return a.repo.Delete(ctx, sessionID) +} + func (a *managedSessionStoreAdapter) UpdateCursor(ctx context.Context, sessionID, lastEventID string) error { return a.repo.UpdateCursor(ctx, sessionID, lastEventID) } @@ -367,6 +383,27 @@ func managedSessionToReg(row store.ManagedAgentSession) managedagents.SessionReg } } +type managedBindingStoreAdapter struct { + repo *store.ManagedAgentBindingRepo +} + +func (a *managedBindingStoreAdapter) List(ctx context.Context) ([]managedagents.AgentBinding, error) { + rows, err := a.repo.List(ctx) + if err != nil { + return nil, err + } + out := make([]managedagents.AgentBinding, 0, len(rows)) + for _, row := range rows { + out = append(out, managedagents.AgentBinding{ + AgentCUID: row.AgentCUID, + Account: row.Account, + ClaudeAgentID: row.ClaudeAgentID, + ClaudeAgentName: row.ClaudeAgentName, + }) + } + return out, nil +} + // managedAuditAdapter bridges the invocation/event repos → // managedagents.InvocationAuditStore for the synthetic per-session audit row. type managedAuditAdapter struct { @@ -460,11 +497,23 @@ func defaultUserDatabasePath() (string, error) { // agentsLookupAdapter bridges store.AgentsRepo → invocation.AgentLookup. type agentsLookupAdapter struct { - repo *store.AgentsRepo + repo *store.AgentsRepo + managedBindings *store.ManagedAgentBindingRepo } func (a *agentsLookupAdapter) GetByAgentID(ctx context.Context, agentID string) (invocation.AgentRecord, error) { rec, err := a.repo.GetByAgentID(ctx, agentID) + if err == nil { + return invocation.AgentRecord{ID: rec.ID, VMCUID: rec.VMCUID, VMOrganizationCUID: rec.VMOrganizationCUID, Charter: rec.Charter}, nil + } + if a.managedBindings == nil { + return invocation.AgentRecord{}, err + } + binding, bindErr := a.managedBindings.GetByClaudeAgentID(ctx, "", agentID) + if bindErr != nil { + return invocation.AgentRecord{}, err + } + rec, err = a.repo.Get(ctx, binding.AgentCUID) if err != nil { return invocation.AgentRecord{}, err } @@ -568,6 +617,13 @@ func truthyEnv(name string) bool { return value == "1" || value == "true" || value == "TRUE" || value == "yes" || value == "YES" } +func emptyDefault(value, fallback string) string { + if value == "" { + return fallback + } + return value +} + // credentialAdapter bridges store.OAuthRepo into the narrow // mcp.CredentialStore interface the resolver consumes. Keeps the mcp // package independent of the concrete OAuthRepo/OAuthCredential types. diff --git a/examples/managed-agents/README.md b/examples/managed-agents/README.md index 3ffdebe..f54e149 100644 --- a/examples/managed-agents/README.md +++ b/examples/managed-agents/README.md @@ -65,15 +65,17 @@ Because Atryum answers the harness's own confirmation prompts, it can gate the ### 1. Enable the bridge in `atryum.toml` -Declare one `[[managed_agents]]` table per Anthropic account/workspace. The -`name` is a unique label the session-registration API uses to target a specific -account. +Declare one `[[managed_agents]]` table per Anthropic account/workspace API key. +The `name` is a unique label the session-registration API uses to target a +specific account. ```toml [[managed_agents]] name = "default" # unique label; targeted by the registration "account" field -# Anthropic API key. Env overrides (single account only): +workspace = "anthropic-workspace-name-or-id" # display/metadata label +# Anthropic API key created in that workspace. Env overrides (single account only): # ATRYUM_MANAGED_AGENTS_API_KEY, then ANTHROPIC_API_KEY. +# If using env for the key, set ATRYUM_MANAGED_AGENTS_WORKSPACE too. api_key = "sk-ant-..." # Optional tuning (defaults shown): # base_url = "https://api.anthropic.com" @@ -85,13 +87,17 @@ api_key = "sk-ant-..." # Watch a second account by repeating the table: # [[managed_agents]] # name = "staging" +# workspace = "staging-workspace" # api_key = "sk-ant-..." ``` Entries with an empty `api_key` are skipped; when no account has a usable key the bridge is disabled and the admin endpoint returns `501`. The `ATRYUM_MANAGED_AGENTS_API_KEY` / `ANTHROPIC_API_KEY` env overrides apply only -when zero or one `[[managed_agents]]` entry is configured. +when zero or one `[[managed_agents]]` entry is configured. `workspace` is +required whenever `api_key` is set, but it is not sent as an Anthropic request +selector: Anthropic API keys are already workspace-scoped, so use an API key +created in the workspace whose Claude agents you want to list. ### 2. Create an agent whose tools ask for confirmation @@ -112,11 +118,16 @@ curl -sS https://api.anthropic.com/v1/agents \ }' ``` -Create an environment and a session as usual (see the -[quickstart](https://platform.claude.com/docs/en/managed-agents/quickstart)), -and note the session ID. +Create an environment and sessions as usual (see the +[quickstart](https://platform.claude.com/docs/en/managed-agents/quickstart)). -### 3. Register the session with Atryum +### 3. Link the Claude agent in Atryum + +Open the Agents page, edit the Atryum agent you want rules to apply to, and +select the Claude Managed Agent. Atryum writes ownership metadata to the Claude +agent and discovers its sessions automatically. + +Manual session registration still exists as an escape hatch: ```bash curl -sS -X POST http://localhost:8080/api/v1/admin/managed-agents/sessions \ @@ -129,10 +140,10 @@ curl -sS -X POST http://localhost:8080/api/v1/admin/managed-agents/sessions \ }' ``` -Atryum starts watching immediately and resumes watching registered sessions on -restart (the cursor is persisted, so it replays anything missed). Send the -session a user message; blocking tool calls now flow through your Atryum rules -and appear live in the invocations UI. +Atryum starts watching linked sessions as it discovers them and resumes watched +sessions on restart (the cursor is persisted, so it replays anything missed). +Send the session a user message; blocking tool calls now flow through your +Atryum rules and appear live in the invocations UI. ### Approval rules diff --git a/integrations/lib/atryum.sh b/integrations/lib/atryum.sh index 4bdc7c0..978a481 100644 --- a/integrations/lib/atryum.sh +++ b/integrations/lib/atryum.sh @@ -69,7 +69,11 @@ start_atryum() { local config_path="$1" build_atryum log "Starting atryum on :${ATRYUM_PORT} (config: $config_path)" - ATRYUM_MCP_DEBUG="${ATRYUM_MCP_DEBUG:-1}" \ + env \ + -u ANTHROPIC_API_KEY \ + -u ATRYUM_MANAGED_AGENTS_API_KEY \ + -u ATRYUM_MANAGED_AGENTS_WORKSPACE \ + ATRYUM_MCP_DEBUG="${ATRYUM_MCP_DEBUG:-1}" \ "$ATRYUM_BIN" run -config "$config_path" \ >"$RUN_DIR/atryum.log" 2>&1 & echo $! >"$RUN_DIR/atryum.pid" @@ -152,4 +156,4 @@ PY fi done log "Direct MCP verification passed" -} \ No newline at end of file +} diff --git a/integrations/lib/harness.sh b/integrations/lib/harness.sh index eaedf4f..5d96293 100644 --- a/integrations/lib/harness.sh +++ b/integrations/lib/harness.sh @@ -318,6 +318,10 @@ PY install_hook "$harness_id" configure_harness_mcp "$harness_id" "$auth_id" "$target_id" + if [[ "$harness_id" == "codex" && -z "${CODEX_API_KEY:-}" && -n "${OPENAI_API_KEY:-}" ]]; then + export CODEX_API_KEY="$OPENAI_API_KEY" + fi + local invoke log "Running harness=$harness_id auth=$auth_id target=$target_id" @@ -395,4 +399,4 @@ PY fi done log "Harness output contained expected value(s): ${expect//|/, }" -} \ No newline at end of file +} diff --git a/integrations/scripts/agent_harness_integration_tests.sh b/integrations/scripts/agent_harness_integration_tests.sh index 813bb71..293250d 100755 --- a/integrations/scripts/agent_harness_integration_tests.sh +++ b/integrations/scripts/agent_harness_integration_tests.sh @@ -113,8 +113,14 @@ run_single_case() { template="$(auth_protocol_template "$auth_id")" config_path="$RUN_DIR/atryum.toml" render_atryum_config "$auth_id" "$target_id" "$template" "$config_path" - start_atryum "$config_path" - seed_auto_approve_rules + start_atryum "$config_path" || { + write_result "$case_name" "failed" "atryum startup failed" + return 1 + } + seed_auto_approve_rules || { + write_result "$case_name" "failed" "failed to seed auto-approve rule" + return 1 + } if (( skip_direct == 0 )); then verify_upstream_direct "$target_id" "$auth_id" || { @@ -210,4 +216,4 @@ case "$CMD" in *) die "unknown command: $CMD (try: help)" ;; -esac \ No newline at end of file +esac diff --git a/internal/api/handlers.go b/internal/api/handlers.go index db65654..2134a4b 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -13,6 +13,7 @@ import ( "log" "mime" "net/http" + "net/url" "os" "sort" "strconv" @@ -99,6 +100,12 @@ type agentsRepo interface { DeleteAll(ctx context.Context) error } +type managedAgentBindingsRepo interface { + ListByAgent(ctx context.Context, agentCUID string) ([]store.ManagedAgentBinding, error) + GetByClaudeAgentID(ctx context.Context, account, claudeAgentID string) (store.ManagedAgentBinding, error) + ReplaceForAgent(ctx context.Context, agentCUID string, bindings []store.ManagedAgentBinding) error +} + type agentSyncSettingsRepo interface { Get(ctx context.Context) (store.AgentSyncSettings, error) Save(ctx context.Context, s store.AgentSyncSettings) error @@ -127,6 +134,7 @@ type Handler struct { agentsRepo agentsRepo agentSyncSettingsRepo agentSyncSettingsRepo llmConfigsRepo llmConfigsRepo + managedAgentBindings managedAgentBindingsRepo backendClient *backendclient.Client summarizeClient invocationSummarizer localSummarizer localInvocationSummarizer @@ -152,9 +160,16 @@ type Handler struct { } // managedAgentsAdmin is the slice of the managed-agents service the admin API -// needs to register a session for watching. +// needs for session registration and Claude agent discovery. type managedAgentsAdmin interface { RegisterSession(ctx context.Context, req managedagents.RegisterSessionRequest) (managedagents.SessionRegistration, error) + ListSessions(ctx context.Context) ([]managedagents.SessionRegistration, error) + DeleteSession(ctx context.Context, sessionID string) error + ClearSessions(ctx context.Context) (int, error) + Accounts() []managedagents.AccountInfo + ListAgents(ctx context.Context, req managedagents.ListAgentsRequest) ([]managedagents.AgentInfo, error) + ClaimAgent(ctx context.Context, req managedagents.AgentClaimRequest) (managedagents.AgentInfo, error) + ReleaseAgent(ctx context.Context, req managedagents.AgentClaimRequest) error } type PolicyStatusResponse struct { @@ -339,14 +354,15 @@ type RuleListResponse struct { // ─── Agent admin types ──────────────────────────────────────────────────────── type AdminAgent struct { - CUID string `json:"cuid"` - OrgName string `json:"org_name"` - Name string `json:"name"` - Description string `json:"description,omitempty"` - AgentIDs []string `json:"agent_ids"` - SyncedAt time.Time `json:"synced_at"` - Enabled bool `json:"enabled"` - Charter string `json:"charter,omitempty"` + CUID string `json:"cuid"` + OrgName string `json:"org_name"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + AgentIDs []string `json:"agent_ids"` + ClaudeManagedAgents []AdminManagedAgentBinding `json:"claude_managed_agents,omitempty"` + SyncedAt time.Time `json:"synced_at"` + Enabled bool `json:"enabled"` + Charter string `json:"charter,omitempty"` // Synced is true when this agent originated from a ValidMind sync // (vm_organization_cuid is non-empty). Synced agents cannot be deleted // manually — they are removed by re-syncing with a different org/record-type. @@ -354,19 +370,48 @@ type AdminAgent struct { } type AdminAgentInput struct { - Enabled bool `json:"enabled"` - AgentIDs []string `json:"agent_ids,omitempty"` - Name string `json:"name,omitempty"` - Description string `json:"description,omitempty"` - Charter string `json:"charter,omitempty"` + Enabled bool `json:"enabled"` + AgentIDs []string `json:"agent_ids,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Charter string `json:"charter,omitempty"` + ClaudeManagedAgents *[]AdminManagedAgentBinding `json:"claude_managed_agents,omitempty"` + ForceClaudeManagedAgentConnect bool `json:"force_claude_managed_agent_connect,omitempty"` } type AdminAgentCreateInput struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Enabled bool `json:"enabled"` - AgentIDs []string `json:"agent_ids,omitempty"` - Charter string `json:"charter,omitempty"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Enabled bool `json:"enabled"` + AgentIDs []string `json:"agent_ids,omitempty"` + Charter string `json:"charter,omitempty"` + ClaudeManagedAgents []AdminManagedAgentBinding `json:"claude_managed_agents,omitempty"` + ForceClaudeManagedAgentConnect bool `json:"force_claude_managed_agent_connect,omitempty"` +} + +type AdminManagedAgentBinding struct { + ID string `json:"id,omitempty"` + Account string `json:"account"` + ClaudeAgentID string `json:"claude_agent_id"` + ClaudeAgentName string `json:"claude_agent_name,omitempty"` + ClaudeAgentModel string `json:"claude_agent_model,omitempty"` + ClaudeAgentVersion int `json:"claude_agent_version,omitempty"` +} + +type ManagedAgentAccountListResponse struct { + Items []managedagents.AccountInfo `json:"items"` +} + +type ManagedAgentSessionListResponse struct { + Items []managedagents.SessionRegistration `json:"items"` +} + +type ManagedAgentSessionClearResponse struct { + Deleted int `json:"deleted"` +} + +type ManagedAgentListResponse struct { + Items []managedagents.AgentInfo `json:"items"` } type AgentListResponse struct { @@ -388,6 +433,158 @@ func toAdminAgent(a store.AgentRecord) AdminAgent { } } +func (h *Handler) toAdminAgent(ctx context.Context, a store.AgentRecord) AdminAgent { + out := toAdminAgent(a) + if h.managedAgentBindings == nil { + return out + } + bindings, err := h.managedAgentBindings.ListByAgent(ctx, a.ID) + if err != nil { + return out + } + out.ClaudeManagedAgents = toAdminManagedAgentBindings(bindings) + return out +} + +func toAdminManagedAgentBindings(bindings []store.ManagedAgentBinding) []AdminManagedAgentBinding { + out := make([]AdminManagedAgentBinding, 0, len(bindings)) + for _, b := range bindings { + out = append(out, AdminManagedAgentBinding{ + ID: b.ID, + Account: b.Account, + ClaudeAgentID: b.ClaudeAgentID, + ClaudeAgentName: b.ClaudeAgentName, + ClaudeAgentModel: b.ClaudeAgentModel, + ClaudeAgentVersion: b.ClaudeAgentVersion, + }) + } + return out +} + +func toStoreManagedAgentBindings(agentCUID string, bindings []AdminManagedAgentBinding) []store.ManagedAgentBinding { + out := make([]store.ManagedAgentBinding, 0, len(bindings)) + seen := make(map[string]bool, len(bindings)) + for _, b := range bindings { + account := strings.TrimSpace(b.Account) + if account == "" { + account = managedagents.DefaultAccountName + } + claudeAgentID := strings.TrimSpace(b.ClaudeAgentID) + if claudeAgentID == "" { + continue + } + key := account + "\x00" + claudeAgentID + if seen[key] { + continue + } + seen[key] = true + id := strings.TrimSpace(b.ID) + if id == "" { + id = uuid.NewString() + } + out = append(out, store.ManagedAgentBinding{ + ID: id, + AgentCUID: agentCUID, + Account: account, + ClaudeAgentID: claudeAgentID, + ClaudeAgentName: strings.TrimSpace(b.ClaudeAgentName), + ClaudeAgentModel: strings.TrimSpace(b.ClaudeAgentModel), + ClaudeAgentVersion: b.ClaudeAgentVersion, + }) + } + return out +} + +func (h *Handler) claimManagedAgentBindings(ctx context.Context, agentCUID string, bindings []store.ManagedAgentBinding, force bool) error { + if len(bindings) == 0 { + return nil + } + if h.managedAgents == nil { + return fmt.Errorf("managed agents bridge not configured (set [managed_agents].api_key)") + } + if h.managedAgentBindings != nil { + for _, binding := range bindings { + existing, err := h.managedAgentBindings.GetByClaudeAgentID(ctx, binding.Account, binding.ClaudeAgentID) + if err == nil && existing.AgentCUID != agentCUID { + if !force { + return &managedagents.AgentClaimConflictError{ClaudeAgentID: binding.ClaudeAgentID, Instance: "local", AgentCUID: existing.AgentCUID} + } + continue + } + if err != nil && err != sql.ErrNoRows { + return err + } + } + } + for _, binding := range bindings { + _, err := h.managedAgents.ClaimAgent(ctx, managedagents.AgentClaimRequest{ + Account: binding.Account, + ClaudeAgentID: binding.ClaudeAgentID, + AtryumAgentCUID: agentCUID, + BindingID: binding.ID, + Force: force, + }) + if err != nil { + return err + } + } + return nil +} + +func (h *Handler) releaseManagedAgentBindings(ctx context.Context, agentCUID string, bindings []store.ManagedAgentBinding) { + if h.managedAgents == nil || len(bindings) == 0 { + return + } + for _, binding := range bindings { + _ = h.managedAgents.ReleaseAgent(ctx, managedagents.AgentClaimRequest{ + Account: binding.Account, + ClaudeAgentID: binding.ClaudeAgentID, + AtryumAgentCUID: agentCUID, + BindingID: binding.ID, + }) + } +} + +func (h *Handler) releaseNewManagedAgentClaims(ctx context.Context, agentCUID string, before, after []store.ManagedAgentBinding) { + if len(after) == 0 { + return + } + existing := make(map[string]bool, len(before)) + for _, binding := range before { + existing[binding.Account+"\x00"+binding.ClaudeAgentID] = true + } + for _, binding := range after { + if existing[binding.Account+"\x00"+binding.ClaudeAgentID] { + continue + } + h.releaseManagedAgentBindings(ctx, agentCUID, []store.ManagedAgentBinding{binding}) + } +} + +func (h *Handler) releaseRemovedManagedAgentBindings(ctx context.Context, agentCUID string, before, after []store.ManagedAgentBinding) { + if h.managedAgents == nil || len(before) == 0 { + return + } + keep := make(map[string]bool, len(after)) + for _, binding := range after { + keep[binding.Account+"\x00"+binding.ClaudeAgentID] = true + } + for _, binding := range before { + if keep[binding.Account+"\x00"+binding.ClaudeAgentID] { + continue + } + h.releaseManagedAgentBindings(ctx, agentCUID, []store.ManagedAgentBinding{binding}) + } +} + +func writeManagedAgentClaimError(w http.ResponseWriter, err error) { + if conflict, ok := err.(*managedagents.AgentClaimConflictError); ok { + writeError(w, http.StatusConflict, conflict.Error()) + return + } + writeError(w, http.StatusBadRequest, err.Error()) +} + func parseAgentIDs(raw string) []string { if raw == "" { return []string{} @@ -495,6 +692,13 @@ func (h *Handler) SetManagedAgents(m managedAgentsAdmin) { h.managedAgents = m } +// SetManagedAgentBindings installs the store used to persist Atryum Agent ↔ +// Claude Managed Agent links. It is optional so narrow unit-test handlers can +// omit it. +func (h *Handler) SetManagedAgentBindings(repo managedAgentBindingsRepo) { + h.managedAgentBindings = repo +} + // SetAPIKeyAuth installs the static api-key/secret pair used to protect the // read-only invocation reporting endpoints. func (h *Handler) SetAPIKeyAuth(cfg auth.APIKeyConfig) { @@ -541,6 +745,9 @@ func (h *Handler) Routes() http.Handler { mux.HandleFunc("/api/v1/admin/vm/custom-fields", h.adminVMCustomFields) mux.HandleFunc("/api/v1/admin/oauth/callback", h.oauthCallback) mux.HandleFunc("/api/v1/admin/policy", h.adminPolicy) + mux.HandleFunc("/api/v1/admin/managed-agents/accounts", h.adminManagedAgentAccounts) + mux.HandleFunc("/api/v1/admin/managed-agents/agents", h.adminManagedAgents) + mux.HandleFunc("/api/v1/admin/managed-agents/sessions/", h.adminManagedAgentSessionDetail) mux.HandleFunc("/api/v1/admin/managed-agents/sessions", h.adminManagedAgentSessions) agentRulesHandler := auth.MiddlewareWithOptions(h.authValidator, "/.well-known/oauth-protected-resource", auth.MiddlewareOptions{SkipVerify: h.authDebugSkip, DebugLogIdentity: h.debug})(http.HandlerFunc(h.agentRules)) agentRulesHandler = h.noAuthAgentIDHint(agentRulesHandler) @@ -2573,7 +2780,7 @@ func (h *Handler) adminAgents(w http.ResponseWriter, r *http.Request) { } items := make([]AdminAgent, 0, len(records)) for _, a := range records { - items = append(items, toAdminAgent(a)) + items = append(items, h.toAdminAgent(r.Context(), a)) } writeJSON(w, http.StatusOK, AgentListResponse{Items: items}) @@ -2617,16 +2824,33 @@ func (h *Handler) adminAgents(w http.ResponseWriter, r *http.Request) { Enabled: req.Enabled, Charter: req.Charter, } + var bindings []store.ManagedAgentBinding + if h.managedAgentBindings != nil && len(req.ClaudeManagedAgents) > 0 { + bindings = toStoreManagedAgentBindings(id, req.ClaudeManagedAgents) + if err := h.claimManagedAgentBindings(r.Context(), id, bindings, req.ForceClaudeManagedAgentConnect); err != nil { + writeManagedAgentClaimError(w, err) + return + } + } if err := h.agentsRepo.Create(r.Context(), agent); err != nil { + h.releaseManagedAgentBindings(r.Context(), id, bindings) writeError(w, http.StatusInternalServerError, "failed to create agent") return } + if len(bindings) > 0 { + if err := h.managedAgentBindings.ReplaceForAgent(r.Context(), id, bindings); err != nil { + h.releaseManagedAgentBindings(r.Context(), id, bindings) + _ = h.agentsRepo.Delete(r.Context(), id) + writeError(w, http.StatusInternalServerError, "failed to save managed agent bindings") + return + } + } record, err := h.agentsRepo.Get(r.Context(), id) if err != nil { writeError(w, http.StatusInternalServerError, "failed to retrieve created agent") return } - writeJSON(w, http.StatusCreated, toAdminAgent(record)) + writeJSON(w, http.StatusCreated, h.toAdminAgent(r.Context(), record)) default: writeError(w, http.StatusMethodNotAllowed, "method not allowed") @@ -2662,7 +2886,7 @@ func (h *Handler) adminAgentDetail(w http.ResponseWriter, r *http.Request) { } items := make([]AdminAgent, 0, len(records)) for _, a := range records { - items = append(items, toAdminAgent(a)) + items = append(items, h.toAdminAgent(r.Context(), a)) } writeJSON(w, http.StatusOK, AgentListResponse{Items: items}) return @@ -2681,7 +2905,7 @@ func (h *Handler) adminAgentDetail(w http.ResponseWriter, r *http.Request) { writeError(w, status, "agent not found") return } - writeJSON(w, http.StatusOK, toAdminAgent(record)) + writeJSON(w, http.StatusOK, h.toAdminAgent(r.Context(), record)) case http.MethodPatch: var req AdminAgentInput @@ -2698,7 +2922,42 @@ func (h *Handler) adminAgentDetail(w http.ResponseWriter, r *http.Request) { return } } + var beforeBindings []store.ManagedAgentBinding + var bindings []store.ManagedAgentBinding + managedBindingsTouched := false + if req.ClaudeManagedAgents != nil { + if h.managedAgentBindings == nil { + writeError(w, http.StatusServiceUnavailable, "managed agent bindings not configured") + return + } + if _, err := h.agentsRepo.Get(r.Context(), id); err != nil { + status := http.StatusInternalServerError + if err == sql.ErrNoRows { + status = http.StatusNotFound + } + writeError(w, status, "agent not found") + return + } + var err error + beforeBindings, err = h.managedAgentBindings.ListByAgent(r.Context(), id) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to list managed agent bindings") + return + } + bindings = toStoreManagedAgentBindings(id, *req.ClaudeManagedAgents) + if err := h.claimManagedAgentBindings(r.Context(), id, bindings, req.ForceClaudeManagedAgentConnect); err != nil { + writeManagedAgentClaimError(w, err) + return + } + managedBindingsTouched = true + } + cleanupNewClaims := func() { + if managedBindingsTouched { + h.releaseNewManagedAgentClaims(r.Context(), id, beforeBindings, bindings) + } + } if err := h.agentsRepo.UpdateEnabled(r.Context(), id, req.Enabled); err != nil { + cleanupNewClaims() status := http.StatusInternalServerError if err == sql.ErrNoRows { status = http.StatusNotFound @@ -2709,10 +2968,12 @@ func (h *Handler) adminAgentDetail(w http.ResponseWriter, r *http.Request) { if req.AgentIDs != nil { idsJSON, err := json.Marshal(req.AgentIDs) if err != nil { + cleanupNewClaims() writeError(w, http.StatusInternalServerError, "failed to encode agent_ids") return } if err := h.agentsRepo.UpdateAgentIDs(r.Context(), id, string(idsJSON)); err != nil { + cleanupNewClaims() status := http.StatusInternalServerError if err == sql.ErrNoRows { status = http.StatusNotFound @@ -2721,8 +2982,9 @@ func (h *Handler) adminAgentDetail(w http.ResponseWriter, r *http.Request) { return } } - if req.Name != "" || req.Charter != "" { + if req.Name != "" || req.Description != "" || req.Charter != "" { if err := h.agentsRepo.UpdateMeta(r.Context(), id, req.Name, req.Description, req.Charter); err != nil { + cleanupNewClaims() status := http.StatusInternalServerError if err == sql.ErrNoRows { status = http.StatusNotFound @@ -2731,12 +2993,20 @@ func (h *Handler) adminAgentDetail(w http.ResponseWriter, r *http.Request) { return } } + if managedBindingsTouched { + if err := h.managedAgentBindings.ReplaceForAgent(r.Context(), id, bindings); err != nil { + cleanupNewClaims() + writeError(w, http.StatusInternalServerError, "failed to save managed agent bindings") + return + } + h.releaseRemovedManagedAgentBindings(r.Context(), id, beforeBindings, bindings) + } record, err := h.agentsRepo.Get(r.Context(), id) if err != nil { writeError(w, http.StatusInternalServerError, "failed to retrieve agent") return } - writeJSON(w, http.StatusOK, toAdminAgent(record)) + writeJSON(w, http.StatusOK, h.toAdminAgent(r.Context(), record)) case http.MethodDelete: record, err := h.agentsRepo.Get(r.Context(), id) @@ -3152,6 +3422,24 @@ func (h *Handler) adminManagedAgentSessions(w http.ResponseWriter, r *http.Reque writeError(w, http.StatusNotImplemented, "managed agents bridge not configured (set [managed_agents].api_key)") return } + if r.Method == http.MethodGet { + sessions, err := h.managedAgents.ListSessions(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + writeJSON(w, http.StatusOK, ManagedAgentSessionListResponse{Items: sessions}) + return + } + if r.Method == http.MethodDelete { + deleted, err := h.managedAgents.ClearSessions(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + writeJSON(w, http.StatusOK, ManagedAgentSessionClearResponse{Deleted: deleted}) + return + } if r.Method != http.MethodPost { h.debugf("managed-agents session registration rejected: method not allowed method=%s path=%s remote=%s", r.Method, r.URL.Path, r.RemoteAddr) writeError(w, http.StatusMethodNotAllowed, "method not allowed") @@ -3174,6 +3462,64 @@ func (h *Handler) adminManagedAgentSessions(w http.ResponseWriter, r *http.Reque writeJSON(w, http.StatusOK, resp) } +func (h *Handler) adminManagedAgentSessionDetail(w http.ResponseWriter, r *http.Request) { + if h.managedAgents == nil { + writeError(w, http.StatusNotImplemented, "managed agents bridge not configured (set [managed_agents].api_key)") + return + } + if r.Method != http.MethodDelete { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + rawID := strings.Trim(strings.TrimPrefix(r.URL.Path, "/api/v1/admin/managed-agents/sessions/"), "/") + sessionID, err := url.PathUnescape(rawID) + if err != nil || strings.TrimSpace(sessionID) == "" { + writeError(w, http.StatusNotFound, "not found") + return + } + if err := h.managedAgents.DeleteSession(r.Context(), sessionID); err != nil { + status := http.StatusBadRequest + if err == sql.ErrNoRows { + status = http.StatusNotFound + } + writeError(w, status, err.Error()) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) adminManagedAgentAccounts(w http.ResponseWriter, r *http.Request) { + if h.managedAgents == nil { + writeError(w, http.StatusNotImplemented, "managed agents bridge not configured (set [managed_agents].api_key)") + return + } + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + writeJSON(w, http.StatusOK, ManagedAgentAccountListResponse{Items: h.managedAgents.Accounts()}) +} + +func (h *Handler) adminManagedAgents(w http.ResponseWriter, r *http.Request) { + if h.managedAgents == nil { + writeError(w, http.StatusNotImplemented, "managed agents bridge not configured (set [managed_agents].api_key)") + return + } + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + agents, err := h.managedAgents.ListAgents(r.Context(), managedagents.ListAgentsRequest{ + Account: r.URL.Query().Get("account"), + Query: r.URL.Query().Get("q"), + }) + if err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + writeJSON(w, http.StatusOK, ManagedAgentListResponse{Items: agents}) +} + func (h *Handler) externalInvocationDetail(w http.ResponseWriter, r *http.Request) { id := strings.Trim(strings.TrimPrefix(r.URL.Path, "/api/v1/external/invocations/"), "/") if id == "" { diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index f86f401..a022eb6 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -142,8 +142,12 @@ func (s *stubSummarizer) SummarizeInvocation(_ context.Context, req backendclien } type stubManagedAgentsAdmin struct { - err error - req managedagents.RegisterSessionRequest + err error + req managedagents.RegisterSessionRequest + agents []managedagents.AgentInfo + sessions []managedagents.SessionRegistration + deleted string + cleared bool } func (s *stubManagedAgentsAdmin) RegisterSession(_ context.Context, req managedagents.RegisterSessionRequest) (managedagents.SessionRegistration, error) { @@ -154,6 +158,48 @@ func (s *stubManagedAgentsAdmin) RegisterSession(_ context.Context, req manageda return managedagents.SessionRegistration{SessionID: req.SessionID, Account: req.Account, AgentID: req.AgentID}, nil } +func (s *stubManagedAgentsAdmin) ListSessions(context.Context) ([]managedagents.SessionRegistration, error) { + if s.err != nil { + return nil, s.err + } + return s.sessions, nil +} + +func (s *stubManagedAgentsAdmin) DeleteSession(_ context.Context, sessionID string) error { + s.deleted = sessionID + return s.err +} + +func (s *stubManagedAgentsAdmin) ClearSessions(context.Context) (int, error) { + s.cleared = true + if s.err != nil { + return 0, s.err + } + return len(s.sessions), nil +} + +func (s *stubManagedAgentsAdmin) Accounts() []managedagents.AccountInfo { + return []managedagents.AccountInfo{{Name: managedagents.DefaultAccountName}} +} + +func (s *stubManagedAgentsAdmin) ListAgents(context.Context, managedagents.ListAgentsRequest) ([]managedagents.AgentInfo, error) { + if s.err != nil { + return nil, s.err + } + return s.agents, nil +} + +func (s *stubManagedAgentsAdmin) ClaimAgent(_ context.Context, req managedagents.AgentClaimRequest) (managedagents.AgentInfo, error) { + if s.err != nil { + return managedagents.AgentInfo{}, s.err + } + return managedagents.AgentInfo{ID: req.ClaudeAgentID, Version: 1}, nil +} + +func (s *stubManagedAgentsAdmin) ReleaseAgent(context.Context, managedagents.AgentClaimRequest) error { + return nil +} + type stubAgentSyncSettingsRepo struct { settings store.AgentSyncSettings } diff --git a/internal/config/config.go b/internal/config/config.go index 6427b4f..d086bd9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -34,9 +34,14 @@ type Config struct { // ManagedAgentsConfig configures one outbound connection to Anthropic's Claude // Managed Agents "events and streaming" API. When APIKey is empty the entry is // skipped. Name distinguishes entries when more than one is configured and is -// used by the session-registration API to target a specific account. +// used by the session-registration API to target a specific account. Workspace +// is required whenever APIKey is set; it labels the Anthropic workspace this +// account belongs to for UI display and Atryum ownership metadata. Anthropic API +// keys are already scoped by Anthropic, so this value does not retarget a key to +// a different workspace. type ManagedAgentsConfig struct { Name string `toml:"name"` + Workspace string `toml:"workspace"` BaseURL string `toml:"base_url"` APIKey string `toml:"api_key"` PollIntervalMillis int `toml:"poll_interval_millis"` @@ -58,13 +63,17 @@ func (c *Config) applyManagedAgentsEnv() { if envKey == "" { return } + envWorkspace := os.Getenv("ATRYUM_MANAGED_AGENTS_WORKSPACE") switch len(c.ManagedAgents) { case 0: - c.ManagedAgents = []ManagedAgentsConfig{{APIKey: envKey}} + c.ManagedAgents = []ManagedAgentsConfig{{APIKey: envKey, Workspace: envWorkspace}} case 1: if c.ManagedAgents[0].APIKey == "" { c.ManagedAgents[0].APIKey = envKey } + if c.ManagedAgents[0].Workspace == "" { + c.ManagedAgents[0].Workspace = envWorkspace + } } } @@ -95,11 +104,12 @@ type PolicyConfig struct { } type ServerConfig struct { - ListenAddr string `toml:"listen_addr"` - PublicBaseURL string `toml:"public_base_url"` - DatabasePath string `toml:"database_path"` - DatabaseURL string `toml:"database_url"` - LogLevel string `toml:"log_level"` + ListenAddr string `toml:"listen_addr"` + PublicBaseURL string `toml:"public_base_url"` + AtryumInstance string `toml:"atryum_instance"` + DatabasePath string `toml:"database_path"` + DatabaseURL string `toml:"database_url"` + LogLevel string `toml:"log_level"` } type DefaultsConfig struct { diff --git a/internal/invocation/service.go b/internal/invocation/service.go index c2335a8..76692c3 100644 --- a/internal/invocation/service.go +++ b/internal/invocation/service.go @@ -825,15 +825,21 @@ func (s *Service) Submit(ctx context.Context, req ExternalSubmitRequest) (Invoca return InvocationResponse{}, err } ruleAction := "" + var matchedRuleID *string var resolvedAIDecision *policy.Decision var resolvedAIConfidence *float64 if s.rules != nil { if approvalRules, err := s.rules.ListApprovalRules(ctx); err == nil { for _, rule := range matchRules(approvalRules, source, req.Tool, agentRec.ID) { r := rule + if r.ID != "" { + id := r.ID + matchedRuleID = &id + } if r.Action == RuleActionAIEvaluation { d, conf := s.runAIEvaluation(ctx, &r, source, req.Tool, req.Input, agentID, agentRec) if d.Disposition == dispositionContinue { + matchedRuleID = nil slog.Info("ai_evaluation: LLM deferred to next rule; continuing rule iteration", "rule_id", r.ID, "server", source, "tool", req.Tool) continue @@ -848,6 +854,7 @@ func (s *Service) Submit(ctx context.Context, req ExternalSubmitRequest) (Invoca } } } + inv.MatchedRuleID = matchedRuleID receivedPayload := map[string]any{"tool": req.Tool, "upstream": source, "request_id": req.RequestID, "input": json.RawMessage(inv.Input), "arguments": json.RawMessage(inv.Input), "external": true} if agentID != "" { diff --git a/internal/invocation/service_test.go b/internal/invocation/service_test.go index b804631..9d0dd30 100644 --- a/internal/invocation/service_test.go +++ b/internal/invocation/service_test.go @@ -283,6 +283,49 @@ func TestSubmitPendingApprovalAutomaticallySummarizesInvocation(t *testing.T) { } } +func TestSubmitMatchesApprovalRuleByExternalToolName(t *testing.T) { + db := newSQLiteTestDB(t) + invRepo := store.NewInvocationRepo(db) + eventRepo := store.NewEventRepo(db) + rulesRepo := store.NewRulesRepo(db) + if err := rulesRepo.Create(context.Background(), store.Rule{ + ID: "rule-managed-bash", + Action: invocation.RuleActionAutoApprove, + ServerPatterns: []string{"claude-managed-agents"}, + ToolPatterns: []string{"Bash"}, + Enabled: true, + }); err != nil { + t.Fatal(err) + } + service := invocation.NewService( + invRepo, + eventRepo, + nil, + nil, + nil, + 5*time.Second, + rulesRepo, + nil, + nil, + nil, + ) + + resp, err := service.Submit(context.Background(), invocation.ExternalSubmitRequest{ + Source: "claude-managed-agents", + Tool: "Bash", + Input: map[string]any{"command": "pwd"}, + }) + if err != nil { + t.Fatal(err) + } + if resp.Status != invocation.StatusApproved { + t.Fatalf("status = %q, want approved", resp.Status) + } + if resp.MatchedRuleID == nil || *resp.MatchedRuleID != "rule-managed-bash" { + t.Fatalf("matched rule id = %v, want rule-managed-bash", resp.MatchedRuleID) + } +} + func TestSubmitAIEvaluationUsesDefaultAgentRecordForUnmappedAgentID(t *testing.T) { db := newSQLiteTestDB(t) invRepo := store.NewInvocationRepo(db) @@ -303,7 +346,6 @@ func TestSubmitAIEvaluationUsesDefaultAgentRecordForUnmappedAgentID(t *testing.T nil, 5*time.Second, rulesStoreStub{rules: []invocation.ApprovalRule{{ - ID: "rule-ai", Action: invocation.RuleActionAIEvaluation, ToolPatterns: []string{"bash"}, ModelConfigCUID: "model-ai", diff --git a/internal/managedagents/anthropic.go b/internal/managedagents/anthropic.go index 0779aa8..d7048cf 100644 --- a/internal/managedagents/anthropic.go +++ b/internal/managedagents/anthropic.go @@ -39,6 +39,282 @@ func (c *httpClient) setHeaders(req *http.Request) { req.Header.Set("content-type", "application/json") } +func (c *httpClient) ListAgents(ctx context.Context) ([]AgentInfo, error) { + values := url.Values{} + values.Set("limit", "100") + var agents []AgentInfo + for { + endpoint := c.base + "/v1/agents" + if encoded := values.Encode(); encoded != "" { + endpoint += "?" + encoded + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + c.setHeaders(req) + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + resp.Body.Close() + return nil, fmt.Errorf("list agents: status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var payload struct { + Data []json.RawMessage `json:"data"` + NextPage string `json:"next_page"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + resp.Body.Close() + return nil, fmt.Errorf("list agents: decode: %w", err) + } + resp.Body.Close() + for _, raw := range payload.Data { + if agent, ok := parseAgentInfo(raw); ok { + agents = append(agents, agent) + } + } + if payload.NextPage == "" { + return agents, nil + } + values.Set("page", payload.NextPage) + } +} + +func (c *httpClient) GetAgent(ctx context.Context, agentID string) (AgentInfo, error) { + endpoint := fmt.Sprintf("%s/v1/agents/%s", c.base, url.PathEscape(agentID)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return AgentInfo{}, err + } + c.setHeaders(req) + resp, err := c.http.Do(req) + if err != nil { + return AgentInfo{}, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return AgentInfo{}, fmt.Errorf("get agent: status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + return AgentInfo{}, err + } + agent, ok := parseAgentInfo(raw) + if !ok { + return AgentInfo{}, fmt.Errorf("get agent: invalid response") + } + return agent, nil +} + +func (c *httpClient) UpdateAgentMetadata(ctx context.Context, agentID string, version int, metadata map[string]*string) (AgentInfo, error) { + body, err := json.Marshal(map[string]any{"version": version, "metadata": metadata}) + if err != nil { + return AgentInfo{}, err + } + endpoint := fmt.Sprintf("%s/v1/agents/%s", c.base, url.PathEscape(agentID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return AgentInfo{}, err + } + c.setHeaders(req) + resp, err := c.http.Do(req) + if err != nil { + return AgentInfo{}, err + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return AgentInfo{}, fmt.Errorf("update agent metadata: status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + return AgentInfo{}, err + } + agent, ok := parseAgentInfo(raw) + if !ok { + return AgentInfo{}, fmt.Errorf("update agent metadata: invalid response") + } + return agent, nil +} + +func (c *httpClient) ListSessions(ctx context.Context, filter SessionListFilter) ([]SessionInfo, error) { + values := url.Values{} + if filter.AgentID != "" { + values.Set("agent_id", filter.AgentID) + } + values.Set("limit", "100") + values.Set("order", "desc") + var sessions []SessionInfo + for { + endpoint := c.base + "/v1/sessions" + if encoded := values.Encode(); encoded != "" { + endpoint += "?" + encoded + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + c.setHeaders(req) + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + resp.Body.Close() + return nil, fmt.Errorf("list sessions: status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var payload struct { + Data []json.RawMessage `json:"data"` + NextPage string `json:"next_page"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + resp.Body.Close() + return nil, fmt.Errorf("list sessions: decode: %w", err) + } + resp.Body.Close() + for _, raw := range payload.Data { + if session, ok := parseSessionInfo(raw); ok { + sessions = append(sessions, session) + } + } + if payload.NextPage == "" { + return sessions, nil + } + values.Set("page", payload.NextPage) + } +} + +func parseAgentInfo(raw json.RawMessage) (AgentInfo, bool) { + var env struct { + ID string `json:"id"` + Name string `json:"name"` + Description *string `json:"description"` + Model json.RawMessage `json:"model"` + Metadata json.RawMessage `json:"metadata"` + Version int `json:"version"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + } + if err := json.Unmarshal(raw, &env); err != nil || env.ID == "" { + return AgentInfo{}, false + } + agent := AgentInfo{ID: env.ID, Name: env.Name, Version: env.Version, Model: parseAgentModel(env.Model), Metadata: parseMetadata(env.Metadata)} + if env.Description != nil { + agent.Description = *env.Description + } + if env.CreatedAt != "" { + if t, err := time.Parse(time.RFC3339Nano, env.CreatedAt); err == nil { + agent.CreatedAt = t.UTC() + } + } + if env.UpdatedAt != "" { + if t, err := time.Parse(time.RFC3339Nano, env.UpdatedAt); err == nil { + agent.UpdatedAt = t.UTC() + } + } + return agent, true +} + +func parseMetadata(raw json.RawMessage) map[string]string { + out := map[string]string{} + if len(raw) == 0 || string(raw) == "null" { + return out + } + var values map[string]any + if err := json.Unmarshal(raw, &values); err != nil { + return out + } + for key, value := range values { + switch v := value.(type) { + case string: + out[key] = v + case bool: + if v { + out[key] = "true" + } else { + out[key] = "false" + } + case float64: + out[key] = fmt.Sprintf("%g", v) + case nil: + // Deleted/empty values are ignored. + default: + b, _ := json.Marshal(v) + out[key] = string(b) + } + } + return out +} + +func parseAgentModel(raw json.RawMessage) string { + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + var obj struct { + ID string `json:"id"` + } + if err := json.Unmarshal(raw, &obj); err == nil { + return obj.ID + } + return "" +} + +func parseSessionInfo(raw json.RawMessage) (SessionInfo, bool) { + var env struct { + ID string `json:"id"` + Agent json.RawMessage `json:"agent"` + AgentID string `json:"agent_id"` + Title string `json:"title"` + Status string `json:"status"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + } + if err := json.Unmarshal(raw, &env); err != nil || env.ID == "" { + return SessionInfo{}, false + } + session := SessionInfo{ID: env.ID, AgentID: env.AgentID, Title: env.Title, Status: env.Status} + if session.AgentID == "" { + session.AgentID = parseSessionAgentID(env.Agent) + } + if env.CreatedAt != "" { + if t, err := time.Parse(time.RFC3339Nano, env.CreatedAt); err == nil { + session.CreatedAt = t.UTC() + } + } + if env.UpdatedAt != "" { + if t, err := time.Parse(time.RFC3339Nano, env.UpdatedAt); err == nil { + session.UpdatedAt = t.UTC() + } + } + return session, true +} + +func parseSessionAgentID(raw json.RawMessage) string { + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + var obj struct { + ID string `json:"id"` + } + if err := json.Unmarshal(raw, &obj); err == nil { + return obj.ID + } + return "" +} + // rawEventEnvelope captures the fields the bridge keys on. The full JSON is // retained separately so downstream handling has access to everything. type rawEventEnvelope struct { @@ -70,30 +346,45 @@ func parseEnvelopeStrict(raw json.RawMessage) (RawEvent, error) { } func (c *httpClient) ListEventsSince(ctx context.Context, sessionID, afterEventID string) ([]RawEvent, error) { - endpoint := fmt.Sprintf("%s/v1/sessions/%s/events", c.base, url.PathEscape(sessionID)) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, err - } - c.setHeaders(req) - resp, err := c.http.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return nil, fmt.Errorf("list events: status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) - } - var payload struct { - Data []json.RawMessage `json:"data"` - } - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { - return nil, fmt.Errorf("list events: decode: %w", err) - } - events := make([]RawEvent, 0, len(payload.Data)) - for _, raw := range payload.Data { - events = append(events, parseEnvelope(raw)) + values := url.Values{} + values.Set("limit", "100") + values.Set("order", "asc") + var events []RawEvent + for { + endpoint := fmt.Sprintf("%s/v1/sessions/%s/events", c.base, url.PathEscape(sessionID)) + if encoded := values.Encode(); encoded != "" { + endpoint += "?" + encoded + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, err + } + c.setHeaders(req) + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + resp.Body.Close() + return nil, fmt.Errorf("list events: status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var payload struct { + Data []json.RawMessage `json:"data"` + NextPage string `json:"next_page"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + resp.Body.Close() + return nil, fmt.Errorf("list events: decode: %w", err) + } + resp.Body.Close() + for _, raw := range payload.Data { + events = append(events, parseEnvelope(raw)) + } + if payload.NextPage == "" { + break + } + values.Set("page", payload.NextPage) } // Drop everything up to and including the cursor so callers only see new // events. The API returns history oldest-first. diff --git a/internal/managedagents/anthropic_test.go b/internal/managedagents/anthropic_test.go index 4234bf7..c0dec01 100644 --- a/internal/managedagents/anthropic_test.go +++ b/internal/managedagents/anthropic_test.go @@ -35,6 +35,86 @@ func TestListEventsSinceFiltersByCursor(t *testing.T) { } } +func TestListEventsSincePaginates(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.URL.Query().Get("limit"); got != "100" { + t.Errorf("limit = %q", got) + } + switch r.URL.Query().Get("page") { + case "": + fmt.Fprint(w, `{"data":[{"id":"e1","type":"agent.message"}],"next_page":"p2"}`) + case "p2": + fmt.Fprint(w, `{"data":[{"id":"e2","type":"session.status_idle"}]}`) + default: + t.Fatalf("unexpected page %q", r.URL.Query().Get("page")) + } + })) + defer srv.Close() + + c := NewAnthropicHTTPClient(Config{BaseURL: srv.URL, APIKey: "k"}) + events, err := c.ListEventsSince(context.Background(), "sess_1", "") + if err != nil { + t.Fatalf("list events: %v", err) + } + if len(events) != 2 || events[0].ID != "e1" || events[1].ID != "e2" { + t.Fatalf("unexpected events: %+v", events) + } +} + +func TestListAgentsParsesModelObject(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/agents" { + t.Fatalf("unexpected path %s", r.URL.Path) + } + if got := r.Header.Get("anthropic-beta"); got != managedAgentsBeta { + t.Errorf("missing beta header, got %q", got) + } + fmt.Fprint(w, `{"data":[{"id":"agent_1","name":"Coding","description":"writes code","model":{"id":"claude-sonnet-4-6","speed":"standard"},"version":3,"created_at":"2026-06-01T00:00:00Z","updated_at":"2026-06-02T00:00:00Z"}]}`) + })) + defer srv.Close() + + c := NewAnthropicHTTPClient(Config{BaseURL: srv.URL, APIKey: "k"}) + agents, err := c.ListAgents(context.Background()) + if err != nil { + t.Fatalf("list agents: %v", err) + } + if len(agents) != 1 { + t.Fatalf("expected 1 agent, got %d", len(agents)) + } + if agents[0].ID != "agent_1" || agents[0].Model != "claude-sonnet-4-6" || agents[0].Version != 3 { + t.Fatalf("unexpected agent: %+v", agents[0]) + } +} + +func TestListAgentsPaginates(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/agents" { + t.Fatalf("unexpected path %s", r.URL.Path) + } + if got := r.URL.Query().Get("limit"); got != "100" { + t.Errorf("limit = %q", got) + } + switch r.URL.Query().Get("page") { + case "": + fmt.Fprint(w, `{"data":[{"id":"agent_1","name":"One"}],"next_page":"p2"}`) + case "p2": + fmt.Fprint(w, `{"data":[{"id":"agent_2","name":"Two"}]}`) + default: + t.Fatalf("unexpected page %q", r.URL.Query().Get("page")) + } + })) + defer srv.Close() + + c := NewAnthropicHTTPClient(Config{BaseURL: srv.URL, APIKey: "k"}) + agents, err := c.ListAgents(context.Background()) + if err != nil { + t.Fatalf("list agents: %v", err) + } + if len(agents) != 2 || agents[0].ID != "agent_1" || agents[1].ID != "agent_2" { + t.Fatalf("unexpected agents: %+v", agents) + } +} + func TestStreamEventsParsesSSE(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") diff --git a/internal/managedagents/parse.go b/internal/managedagents/parse.go index 31e6e41..dd5763e 100644 --- a/internal/managedagents/parse.go +++ b/internal/managedagents/parse.go @@ -110,7 +110,7 @@ func parseToolResult(evt RawEvent) (toolResult, bool) { } m := asObject(evt.Raw) tr := toolResult{ - ToolUseID: firstString(m, "tool_use_id", "custom_tool_use_id", "tool_use_event_id"), + ToolUseID: firstString(m, "tool_use_id", "mcp_tool_use_id", "custom_tool_use_id", "tool_use_event_id"), } if tr.ToolUseID == "" { return toolResult{}, false diff --git a/internal/managedagents/service.go b/internal/managedagents/service.go index c2a581a..60700af 100644 --- a/internal/managedagents/service.go +++ b/internal/managedagents/service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "sort" "strings" "sync" "time" @@ -25,9 +26,16 @@ type SessionStore interface { Upsert(ctx context.Context, s SessionRegistration) error Get(ctx context.Context, sessionID string) (SessionRegistration, error) List(ctx context.Context) ([]SessionRegistration, error) + Delete(ctx context.Context, sessionID string) error UpdateCursor(ctx context.Context, sessionID, lastEventID string) error } +// BindingStore supplies Atryum Agent ↔ Claude Managed Agent links for automatic +// session discovery. +type BindingStore interface { + List(ctx context.Context) ([]AgentBinding, error) +} + // Account pairs an Anthropic API client with its (defaulted) config. type account struct { client AnthropicClient @@ -46,9 +54,11 @@ type Service struct { inv InvocationGateway sessions SessionStore audit InvocationAuditStore + bindings BindingStore accounts map[string]account defaultAccount string // the sole account's name, when exactly one is configured + instanceName string rootCtx context.Context cancel context.CancelFunc @@ -58,15 +68,38 @@ type Service struct { wg sync.WaitGroup } +const ( + DefaultInstanceName = "atryum" + metadataManagedKey = "atryum_managed" + metadataInstanceKey = "atryum_instance" + metadataWorkspaceKey = "atryum_workspace" + metadataAgentCUIDKey = "atryum_agent_cuid" + metadataBindingIDKey = "atryum_binding_id" + metadataBoundAtKey = "atryum_bound_at" +) + +// AgentClaimConflictError reports an existing Atryum ownership marker on a +// Claude Managed Agent. +type AgentClaimConflictError struct { + ClaudeAgentID string + Instance string + AgentCUID string +} + +func (e *AgentClaimConflictError) Error() string { + return fmt.Sprintf("Claude managed agent %s is already linked to Atryum instance %q agent %q", e.ClaudeAgentID, e.Instance, e.AgentCUID) +} + // NewService builds a bridge over one or more Anthropic accounts. Each account's // config is normalized (name defaulted) on construction. func NewService(inv InvocationGateway, sessions SessionStore, audit InvocationAuditStore, accounts []Account) (*Service, error) { s := &Service{ - inv: inv, - sessions: sessions, - audit: audit, - accounts: make(map[string]account, len(accounts)), - watchers: make(map[string]context.CancelFunc), + inv: inv, + sessions: sessions, + audit: audit, + accounts: make(map[string]account, len(accounts)), + watchers: make(map[string]context.CancelFunc), + instanceName: DefaultInstanceName, } for _, a := range accounts { cfg := a.Config.withDefaults() @@ -83,6 +116,16 @@ func NewService(inv InvocationGateway, sessions SessionStore, audit InvocationAu return s, nil } +// SetInstanceName sets the stable Atryum identity written to Claude agent +// metadata. Empty values fall back to DefaultInstanceName. +func (s *Service) SetInstanceName(name string) { + name = strings.TrimSpace(name) + if name == "" { + name = DefaultInstanceName + } + s.instanceName = name +} + // Start resumes watching every previously-registered session. It is safe to // call once at startup. func (s *Service) Start(ctx context.Context) error { @@ -93,17 +136,28 @@ func (s *Service) Start(ctx context.Context) error { } started := 0 for _, reg := range regs { - if _, ok := s.accounts[reg.Account]; !ok { + acct, ok := s.accounts[reg.Account] + if !ok { slog.Warn("managed agents: skipping session for unknown account", "session_id", reg.SessionID, "account", reg.Account) continue } + reg = s.rewindCursorIfParkedOnPendingAction(ctx, reg, acct.client) s.startWatcher(reg) started++ } + if s.bindings != nil { + s.startDiscovery() + } slog.Info("managed agents: started", "accounts", len(s.accounts), "watched_sessions", started) return nil } +// SetBindings enables automatic discovery of Claude sessions for configured +// Atryum Agent ↔ Claude Managed Agent links. Call before Start. +func (s *Service) SetBindings(bindings BindingStore) { + s.bindings = bindings +} + // Close stops all watchers and waits for them to drain. func (s *Service) Close() error { if s.cancel != nil { @@ -144,6 +198,352 @@ func (s *Service) RegisterSession(ctx context.Context, req RegisterSessionReques return stored, nil } +// ListSessions returns every persisted Claude Managed Agents session watcher. +func (s *Service) ListSessions(ctx context.Context) ([]SessionRegistration, error) { + return s.sessions.List(ctx) +} + +// DeleteSession stops a watcher and removes its persisted registration. +func (s *Service) DeleteSession(ctx context.Context, sessionID string) error { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return fmt.Errorf("session_id is required") + } + if err := s.sessions.Delete(ctx, sessionID); err != nil { + return err + } + s.mu.Lock() + if cancel, ok := s.watchers[sessionID]; ok { + cancel() + delete(s.watchers, sessionID) + } + s.mu.Unlock() + return nil +} + +// ClearSessions stops and removes every persisted Claude Managed Agents session +// watcher. Linked agents will be rediscovered by the discovery loop. +func (s *Service) ClearSessions(ctx context.Context) (int, error) { + sessions, err := s.sessions.List(ctx) + if err != nil { + return 0, err + } + for _, session := range sessions { + if err := s.sessions.Delete(ctx, session.SessionID); err != nil { + return 0, err + } + } + s.mu.Lock() + for _, session := range sessions { + if cancel, ok := s.watchers[session.SessionID]; ok { + cancel() + delete(s.watchers, session.SessionID) + } + } + s.mu.Unlock() + return len(sessions), nil +} + +func (s *Service) startDiscovery() { + if s.rootCtx == nil { + s.rootCtx, s.cancel = context.WithCancel(context.Background()) + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.discoveryLoop(s.rootCtx) + }() +} + +func (s *Service) discoveryLoop(ctx context.Context) { + // Discover immediately so the UI binding becomes active without waiting for + // the first tick. + s.discoverSessions(ctx) + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.discoverSessions(ctx) + } + } +} + +func (s *Service) discoverSessions(ctx context.Context) { + bindings, err := s.bindings.List(ctx) + if err != nil { + slog.Warn("managed agents: list bindings failed", "error", err) + return + } + if len(bindings) == 0 { + return + } + regs, err := s.sessions.List(ctx) + if err != nil { + slog.Warn("managed agents: list watched sessions failed", "error", err) + return + } + known := make(map[string]bool, len(regs)) + for _, reg := range regs { + known[reg.SessionID] = true + } + for _, binding := range bindings { + if ctx.Err() != nil { + return + } + claudeAgentID := strings.TrimSpace(binding.ClaudeAgentID) + if claudeAgentID == "" { + continue + } + accountName, err := s.resolveAccount(strings.TrimSpace(binding.Account)) + if err != nil { + slog.Warn("managed agents: skipping binding with unknown account", "agent_cuid", binding.AgentCUID, "account", binding.Account, "error", err) + continue + } + acct := s.accounts[accountName] + sessions, err := acct.client.ListSessions(ctx, SessionListFilter{AgentID: claudeAgentID}) + if err != nil { + slog.Warn("managed agents: discover sessions failed", "account", accountName, "claude_agent_id", claudeAgentID, "error", err) + continue + } + for _, session := range sessions { + if session.ID == "" || known[session.ID] { + continue + } + description := strings.TrimSpace(session.Title) + if description == "" { + description = strings.TrimSpace(binding.ClaudeAgentName) + } + lastEventID, _, _ := s.discoveryTailCursor(ctx, acct.client, session.ID) + reg := SessionRegistration{ + SessionID: session.ID, + Account: accountName, + AgentID: claudeAgentID, + Description: description, + LastEventID: lastEventID, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + } + if err := s.sessions.Upsert(ctx, reg); err != nil { + slog.Warn("managed agents: persist discovered session failed", "session_id", session.ID, "account", accountName, "claude_agent_id", claudeAgentID, "error", err) + continue + } + stored, err := s.sessions.Get(ctx, session.ID) + if err != nil { + stored = reg + } + s.startWatcher(stored) + known[session.ID] = true + slog.Info("managed agents: discovered session", "session_id", session.ID, "account", accountName, "claude_agent_id", claudeAgentID) + } + } +} + +func (s *Service) rewindCursorIfParkedOnPendingAction(ctx context.Context, reg SessionRegistration, client AnthropicClient) SessionRegistration { + cursor, latestID, pending := s.discoveryTailCursor(ctx, client, reg.SessionID) + if !pending || latestID == "" || reg.LastEventID != latestID { + return reg + } + reg.LastEventID = cursor + if err := s.sessions.UpdateCursor(ctx, reg.SessionID, cursor); err != nil { + slog.Warn("managed agents: could not rewind pending-action cursor", "session_id", reg.SessionID, "error", err) + } + return reg +} + +func (s *Service) discoveryTailCursor(ctx context.Context, client AnthropicClient, sessionID string) (cursor string, latestID string, pending bool) { + events, err := client.ListEventsSince(ctx, sessionID, "") + if err != nil { + slog.Warn("managed agents: list discovery tail failed", "session_id", sessionID, "error", err) + return "", "", false + } + if len(events) == 0 { + return "", "", false + } + for i := len(events) - 1; i >= 0; i-- { + if events[i].ID != "" { + latestID = events[i].ID + break + } + } + latest := events[len(events)-1] + blockingIDs, ok := requiresAction(latest) + if !ok || len(blockingIDs) == 0 { + return latestID, latestID, false + } + needed := make(map[string]bool, len(blockingIDs)) + for _, id := range blockingIDs { + needed[id] = true + } + firstPendingIdx := -1 + for i, evt := range events { + if needed[evt.ID] && isToolUse(evt.Type) { + firstPendingIdx = i + break + } + } + if firstPendingIdx < 0 { + slog.Warn("managed agents: requires_action references missing tool-use event in discovery tail; replaying session to recover", "session_id", sessionID) + return "", latestID, true + } + for i := firstPendingIdx - 1; i >= 0; i-- { + if events[i].ID != "" { + return events[i].ID, latestID, true + } + } + return "", latestID, true +} + +// Accounts returns the configured Anthropic account names available for admin +// UI selection. +func (s *Service) Accounts() []AccountInfo { + names := make([]string, 0, len(s.accounts)) + for name := range s.accounts { + names = append(names, name) + } + sort.Strings(names) + out := make([]AccountInfo, 0, len(names)) + for _, name := range names { + out = append(out, AccountInfo{Name: name, Workspace: s.accounts[name].cfg.Workspace}) + } + return out +} + +// ListAgents lists Claude Managed Agents in a configured account. The query is +// applied locally so callers are not coupled to Anthropic's filter support. +func (s *Service) ListAgents(ctx context.Context, req ListAgentsRequest) ([]AgentInfo, error) { + accountName, err := s.resolveAccount(strings.TrimSpace(req.Account)) + if err != nil { + return nil, err + } + agents, err := s.accounts[accountName].client.ListAgents(ctx) + if err != nil { + return nil, err + } + q := strings.ToLower(strings.TrimSpace(req.Query)) + if q == "" { + return agents, nil + } + filtered := make([]AgentInfo, 0, len(agents)) + for _, a := range agents { + if strings.Contains(strings.ToLower(a.ID), q) || strings.Contains(strings.ToLower(a.Name), q) || strings.Contains(strings.ToLower(a.Description), q) { + filtered = append(filtered, a) + } + } + return filtered, nil +} + +// ClaimAgent writes Atryum ownership metadata to a Claude Managed Agent after +// checking whether another Atryum instance/agent already claimed it. +func (s *Service) ClaimAgent(ctx context.Context, req AgentClaimRequest) (AgentInfo, error) { + accountName, err := s.resolveAccount(strings.TrimSpace(req.Account)) + if err != nil { + return AgentInfo{}, err + } + agentID := strings.TrimSpace(req.ClaudeAgentID) + if agentID == "" { + return AgentInfo{}, fmt.Errorf("claude_agent_id is required") + } + agentCUID := strings.TrimSpace(req.AtryumAgentCUID) + if agentCUID == "" { + return AgentInfo{}, fmt.Errorf("atryum agent cuid is required") + } + agent, err := s.accounts[accountName].client.GetAgent(ctx, agentID) + if err != nil { + return AgentInfo{}, err + } + if agent.Metadata == nil { + agent.Metadata = map[string]string{} + } + ownerInstance := agent.Metadata[metadataInstanceKey] + ownerWorkspace := agent.Metadata[metadataWorkspaceKey] + ownerAgentCUID := agent.Metadata[metadataAgentCUIDKey] + bindingID := strings.TrimSpace(req.BindingID) + if strings.EqualFold(agent.Metadata[metadataManagedKey], "true") && !req.Force { + if ownerInstance != "" && ownerInstance != s.instanceName || ownerWorkspace != "" && ownerWorkspace != s.accounts[accountName].cfg.Workspace || ownerAgentCUID != "" && ownerAgentCUID != agentCUID { + return AgentInfo{}, &AgentClaimConflictError{ClaudeAgentID: agentID, Instance: ownerInstance, AgentCUID: ownerAgentCUID} + } + } + boundAt := time.Now().UTC().Format(time.RFC3339) + if ownerInstance == s.instanceName && ownerWorkspace == s.accounts[accountName].cfg.Workspace && ownerAgentCUID == agentCUID && (bindingID == "" || agent.Metadata[metadataBindingIDKey] == bindingID) && agent.Metadata[metadataBoundAtKey] != "" { + boundAt = agent.Metadata[metadataBoundAtKey] + } + patch := map[string]*string{ + metadataManagedKey: strPtr("true"), + metadataInstanceKey: strPtr(s.instanceName), + metadataWorkspaceKey: strPtr(s.accounts[accountName].cfg.Workspace), + metadataAgentCUIDKey: strPtr(agentCUID), + metadataBindingIDKey: strPtr(bindingID), + metadataBoundAtKey: strPtr(boundAt), + } + if metadataPatchIsNoop(agent.Metadata, patch) { + return agent, nil + } + return s.accounts[accountName].client.UpdateAgentMetadata(ctx, agentID, agent.Version, patch) +} + +// ReleaseAgent removes Atryum ownership metadata when the marker still belongs +// to this Atryum instance and binding. It is intentionally best-effort; stale +// metadata should not block local unlinking. +func (s *Service) ReleaseAgent(ctx context.Context, req AgentClaimRequest) error { + accountName, err := s.resolveAccount(strings.TrimSpace(req.Account)) + if err != nil { + return err + } + agentID := strings.TrimSpace(req.ClaudeAgentID) + if agentID == "" { + return nil + } + agent, err := s.accounts[accountName].client.GetAgent(ctx, agentID) + if err != nil { + return err + } + if agent.Metadata == nil || !strings.EqualFold(agent.Metadata[metadataManagedKey], "true") { + return nil + } + if !req.Force { + if agent.Metadata[metadataInstanceKey] != s.instanceName { + return nil + } + if agent.Metadata[metadataWorkspaceKey] != "" && agent.Metadata[metadataWorkspaceKey] != s.accounts[accountName].cfg.Workspace { + return nil + } + if req.AtryumAgentCUID != "" && agent.Metadata[metadataAgentCUIDKey] != req.AtryumAgentCUID { + return nil + } + if req.BindingID != "" && agent.Metadata[metadataBindingIDKey] != req.BindingID { + return nil + } + } + patch := map[string]*string{ + metadataManagedKey: nil, + metadataInstanceKey: nil, + metadataWorkspaceKey: nil, + metadataAgentCUIDKey: nil, + metadataBindingIDKey: nil, + metadataBoundAtKey: nil, + } + _, err = s.accounts[accountName].client.UpdateAgentMetadata(ctx, agentID, agent.Version, patch) + return err +} + +func metadataPatchIsNoop(existing map[string]string, patch map[string]*string) bool { + for key, value := range patch { + if value == nil { + if _, ok := existing[key]; ok { + return false + } + continue + } + if existing[key] != *value { + return false + } + } + return true +} + // resolveAccount validates a requested account name (or picks the sole account // when the name is empty and exactly one account is configured). func (s *Service) resolveAccount(name string) (string, error) { diff --git a/internal/managedagents/types.go b/internal/managedagents/types.go index d38d4f4..114e5b8 100644 --- a/internal/managedagents/types.go +++ b/internal/managedagents/types.go @@ -13,6 +13,7 @@ package managedagents import ( "context" "encoding/json" + "strings" "time" ) @@ -33,6 +34,7 @@ type Config struct { // Name identifies the account when more than one is configured. Empty is // allowed (and normalized to "default") for the single-account case. Name string + Workspace string BaseURL string APIKey string PollInterval time.Duration @@ -45,6 +47,7 @@ type Config struct { const DefaultAccountName = "default" func (c Config) withDefaults() Config { + c.Workspace = strings.TrimSpace(c.Workspace) if c.Name == "" { c.Name = DefaultAccountName } @@ -111,6 +114,68 @@ type SessionRegistration struct { UpdatedAt time.Time `json:"updated_at"` } +// AccountInfo describes one configured Anthropic account available to the +// managed-agents bridge. +type AccountInfo struct { + Name string `json:"name"` + Workspace string `json:"workspace"` +} + +// ListAgentsRequest selects one configured account and optionally filters the +// returned Claude Managed Agents by a case-insensitive local query. +type ListAgentsRequest struct { + Account string `json:"account,omitempty"` + Query string `json:"q,omitempty"` +} + +// AgentInfo is the subset of Anthropic's Claude Managed Agent resource that the +// Atryum admin UI needs for linking agent records. +type AgentInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Model string `json:"model,omitempty"` + Version int `json:"version,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` +} + +// AgentClaimRequest marks a Claude Managed Agent as linked to an Atryum agent. +// Force permits replacing an existing Atryum ownership marker. +type AgentClaimRequest struct { + Account string + ClaudeAgentID string + AtryumAgentCUID string + BindingID string + Force bool +} + +// SessionListFilter selects Claude Managed Agent sessions to discover. +type SessionListFilter struct { + AgentID string +} + +// SessionInfo is the subset of Anthropic's session resource needed by the +// discovery reconciler. +type SessionInfo struct { + ID string + AgentID string + Title string + Status string + CreatedAt time.Time + UpdatedAt time.Time +} + +// AgentBinding links an Atryum agent record to a Claude Managed Agent. The +// managedagents package owns discovery but not the storage implementation. +type AgentBinding struct { + AgentCUID string + Account string + ClaudeAgentID string + ClaudeAgentName string +} + // toolUse holds the normalized fields parsed out of an agent.tool_use, // agent.mcp_tool_use, or agent.custom_tool_use event. type toolUse struct { @@ -125,6 +190,14 @@ type toolUse struct { // AnthropicClient is the minimal Anthropic Managed Agents surface the bridge // needs. It is an interface so tests can supply a fake. type AnthropicClient interface { + // ListAgents returns Claude Managed Agents configured in the account. + ListAgents(ctx context.Context) ([]AgentInfo, error) + // GetAgent retrieves one Claude Managed Agent, including metadata and version. + GetAgent(ctx context.Context, agentID string) (AgentInfo, error) + // UpdateAgentMetadata patches metadata on one Claude Managed Agent. + UpdateAgentMetadata(ctx context.Context, agentID string, version int, metadata map[string]*string) (AgentInfo, error) + // ListSessions returns Claude Managed Agent sessions matching the filter. + ListSessions(ctx context.Context, filter SessionListFilter) ([]SessionInfo, error) // ListEventsSince returns events newer than afterEventID, oldest first. // When afterEventID is empty it returns the full history. ListEventsSince(ctx context.Context, sessionID, afterEventID string) ([]RawEvent, error) diff --git a/internal/managedagents/watcher_test.go b/internal/managedagents/watcher_test.go index 4e31b8b..ac09aa8 100644 --- a/internal/managedagents/watcher_test.go +++ b/internal/managedagents/watcher_test.go @@ -92,6 +92,18 @@ func (c *fakeClient) ListEventsSince(ctx context.Context, sessionID, after strin defer c.mu.Unlock() return append([]RawEvent(nil), c.history...), nil } +func (c *fakeClient) ListAgents(ctx context.Context) ([]AgentInfo, error) { + return nil, nil +} +func (c *fakeClient) GetAgent(ctx context.Context, agentID string) (AgentInfo, error) { + return AgentInfo{ID: agentID, Version: 1}, nil +} +func (c *fakeClient) UpdateAgentMetadata(ctx context.Context, agentID string, version int, metadata map[string]*string) (AgentInfo, error) { + return AgentInfo{ID: agentID, Version: version + 1}, nil +} +func (c *fakeClient) ListSessions(ctx context.Context, filter SessionListFilter) ([]SessionInfo, error) { + return nil, nil +} func (c *fakeClient) StreamEvents(ctx context.Context, sessionID string) (EventStream, error) { return nil, context.Canceled } @@ -111,6 +123,16 @@ type eofStreamClient struct{} func (c eofStreamClient) ListEventsSince(ctx context.Context, sessionID, after string) ([]RawEvent, error) { return nil, nil } +func (c eofStreamClient) ListAgents(ctx context.Context) ([]AgentInfo, error) { return nil, nil } +func (c eofStreamClient) GetAgent(ctx context.Context, agentID string) (AgentInfo, error) { + return AgentInfo{ID: agentID, Version: 1}, nil +} +func (c eofStreamClient) UpdateAgentMetadata(ctx context.Context, agentID string, version int, metadata map[string]*string) (AgentInfo, error) { + return AgentInfo{ID: agentID, Version: version + 1}, nil +} +func (c eofStreamClient) ListSessions(ctx context.Context, filter SessionListFilter) ([]SessionInfo, error) { + return nil, nil +} func (c eofStreamClient) StreamEvents(ctx context.Context, sessionID string) (EventStream, error) { return eofStream{}, nil } @@ -133,6 +155,7 @@ func (s *fakeSessionStore) Get(ctx context.Context, id string) (SessionRegistrat return SessionRegistration{SessionID: id}, nil } func (s *fakeSessionStore) List(ctx context.Context) ([]SessionRegistration, error) { return nil, nil } +func (s *fakeSessionStore) Delete(ctx context.Context, id string) error { return nil } func (s *fakeSessionStore) UpdateCursor(ctx context.Context, id, lastEventID string) error { s.mu.Lock() defer s.mu.Unlock() diff --git a/internal/store/db_test.go b/internal/store/db_test.go index 68a09dc..ca9031d 100644 --- a/internal/store/db_test.go +++ b/internal/store/db_test.go @@ -46,7 +46,7 @@ func TestResolveDBTarget_SelectsSQLiteForSQLiteFileAndBarePaths(t *testing.T) { } func TestMigrationRegistryPreservesExistingVersionsAndNames(t *testing.T) { - if len(migrations) != 22 { + if len(migrations) != 23 { t.Fatalf("migration count = %d", len(migrations)) } want := []struct { @@ -75,6 +75,7 @@ func TestMigrationRegistryPreservesExistingVersionsAndNames(t *testing.T) { {20, "020_managed_agent_sessions"}, {21, "021_rename_constitution_to_charter"}, {22, "022_drop_agent_id_pattern"}, + {23, "023_managed_agent_bindings"}, } for i, w := range want { if migrations[i].Version != w.version || migrations[i].Name != w.name { @@ -85,10 +86,10 @@ func TestMigrationRegistryPreservesExistingVersionsAndNames(t *testing.T) { func TestGetPendingMigrationsUsesRegistryOrder(t *testing.T) { pending := getPendingMigrations(map[int]bool{1: true}) - if len(pending) != 21 { + if len(pending) != 22 { t.Fatalf("pending count = %d", len(pending)) } - wantVersions := []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22} + wantVersions := []int{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23} for i, want := range wantVersions { if pending[i].Version != want { t.Fatalf("pending[%d].Version = %d, want %d", i, pending[i].Version, want) diff --git a/internal/store/managed_agent_bindings.go b/internal/store/managed_agent_bindings.go new file mode 100644 index 0000000..f79bdfc --- /dev/null +++ b/internal/store/managed_agent_bindings.go @@ -0,0 +1,169 @@ +package store + +import ( + "context" + "database/sql" + "time" + + sq "github.com/Masterminds/squirrel" + "github.com/google/uuid" +) + +// ManagedAgentBinding links one Atryum agent record to one Claude Managed Agent +// in a configured Anthropic account. Sessions are still watched separately; this +// binding is the durable agent-to-agent mapping used by the UI and discovery. +type ManagedAgentBinding struct { + ID string + AgentCUID string + Account string + ClaudeAgentID string + ClaudeAgentName string + ClaudeAgentModel string + ClaudeAgentVersion int + CreatedAt time.Time + UpdatedAt time.Time +} + +// ManagedAgentBindingRepo provides CRUD operations for managed_agent_bindings. +type ManagedAgentBindingRepo struct { + db *sql.DB + sb sq.StatementBuilderType +} + +func NewManagedAgentBindingRepo(db *sql.DB) *ManagedAgentBindingRepo { + return NewManagedAgentBindingRepoWithDialect(db, DialectSQLite) +} + +func NewManagedAgentBindingRepoWithDialect(db *sql.DB, dialect Dialect) *ManagedAgentBindingRepo { + return &ManagedAgentBindingRepo{db: db, sb: statementBuilderForDialect(dialect)} +} + +var managedAgentBindingColumns = []string{ + "id", "agent_cuid", "account", "claude_agent_id", "claude_agent_name", + "claude_agent_model", "claude_agent_version", "created_at", "updated_at", +} + +func (r *ManagedAgentBindingRepo) ListByAgent(ctx context.Context, agentCUID string) ([]ManagedAgentBinding, error) { + query, args, err := r.sb.Select(managedAgentBindingColumns...). + From("managed_agent_bindings"). + Where(sq.Eq{"agent_cuid": agentCUID}). + OrderBy("account ASC", "claude_agent_name ASC", "claude_agent_id ASC"). + ToSql() + if err != nil { + return nil, err + } + return r.list(ctx, query, args...) +} + +func (r *ManagedAgentBindingRepo) List(ctx context.Context) ([]ManagedAgentBinding, error) { + query, args, err := r.sb.Select(managedAgentBindingColumns...). + From("managed_agent_bindings"). + OrderBy("account ASC", "claude_agent_name ASC", "claude_agent_id ASC"). + ToSql() + if err != nil { + return nil, err + } + return r.list(ctx, query, args...) +} + +func (r *ManagedAgentBindingRepo) GetByClaudeAgentID(ctx context.Context, account, claudeAgentID string) (ManagedAgentBinding, error) { + b := r.sb.Select(managedAgentBindingColumns...). + From("managed_agent_bindings"). + Where(sq.Eq{"claude_agent_id": claudeAgentID}) + if account != "" { + b = b.Where(sq.Eq{"account": account}) + } + query, args, err := b.OrderBy("updated_at DESC").Limit(1).ToSql() + if err != nil { + return ManagedAgentBinding{}, err + } + return scanManagedAgentBinding(r.db.QueryRowContext(ctx, query, args...)) +} + +// ReplaceForAgent atomically replaces all Claude Managed Agent bindings for an +// Atryum agent. This matches the edit-modal save semantics. +func (r *ManagedAgentBindingRepo) ReplaceForAgent(ctx context.Context, agentCUID string, bindings []ManagedAgentBinding) error { + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + deleteQuery, deleteArgs, err := r.sb.Delete("managed_agent_bindings").Where(sq.Eq{"agent_cuid": agentCUID}).ToSql() + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, deleteQuery, deleteArgs...); err != nil { + return err + } + + now := time.Now().UTC() + for _, b := range bindings { + if b.ID == "" { + b.ID = uuid.NewString() + } + if b.Account == "" { + b.Account = "default" + } + if b.CreatedAt.IsZero() { + b.CreatedAt = now + } + b.UpdatedAt = now + conflictDeleteQuery, conflictDeleteArgs, err := r.sb.Delete("managed_agent_bindings"). + Where(sq.Eq{"account": b.Account, "claude_agent_id": b.ClaudeAgentID}). + Where(sq.NotEq{"agent_cuid": agentCUID}). + ToSql() + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, conflictDeleteQuery, conflictDeleteArgs...); err != nil { + return err + } + query, args, err := r.sb.Insert("managed_agent_bindings"). + Columns(managedAgentBindingColumns...). + Values(b.ID, agentCUID, b.Account, b.ClaudeAgentID, b.ClaudeAgentName, b.ClaudeAgentModel, b.ClaudeAgentVersion, b.CreatedAt, b.UpdatedAt). + ToSql() + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, query, args...); err != nil { + return err + } + } + return tx.Commit() +} + +func (r *ManagedAgentBindingRepo) list(ctx context.Context, query string, args ...any) ([]ManagedAgentBinding, error) { + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + var out []ManagedAgentBinding + for rows.Next() { + b, err := scanManagedAgentBinding(rows) + if err != nil { + return nil, err + } + out = append(out, b) + } + return out, rows.Err() +} + +func scanManagedAgentBinding(row interface{ Scan(dest ...any) error }) (ManagedAgentBinding, error) { + var b ManagedAgentBinding + if err := row.Scan( + &b.ID, + &b.AgentCUID, + &b.Account, + &b.ClaudeAgentID, + &b.ClaudeAgentName, + &b.ClaudeAgentModel, + &b.ClaudeAgentVersion, + &b.CreatedAt, + &b.UpdatedAt, + ); err != nil { + return ManagedAgentBinding{}, err + } + return b, nil +} diff --git a/internal/store/managed_agent_sessions.go b/internal/store/managed_agent_sessions.go index c950a64..431339e 100644 --- a/internal/store/managed_agent_sessions.go +++ b/internal/store/managed_agent_sessions.go @@ -102,6 +102,26 @@ func (r *ManagedAgentSessionRepo) List(ctx context.Context) ([]ManagedAgentSessi return out, rows.Err() } +// Delete removes a watched session registration. It returns sql.ErrNoRows when +// no session with the given ID exists. +func (r *ManagedAgentSessionRepo) Delete(ctx context.Context, sessionID string) error { + query, args, err := r.sb. + Delete("managed_agent_sessions"). + Where(sq.Eq{"session_id": sessionID}). + ToSql() + if err != nil { + return err + } + res, err := r.db.ExecContext(ctx, query, args...) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n == 0 { + return sql.ErrNoRows + } + return nil +} + // UpdateCursor advances the last_event_id watermark for a session. func (r *ManagedAgentSessionRepo) UpdateCursor(ctx context.Context, sessionID, lastEventID string) error { query, args, err := r.sb.Update("managed_agent_sessions"). diff --git a/internal/store/migrations/023_managed_agent_bindings.go b/internal/store/migrations/023_managed_agent_bindings.go new file mode 100644 index 0000000..af1b058 --- /dev/null +++ b/internal/store/migrations/023_managed_agent_bindings.go @@ -0,0 +1,41 @@ +package migrations + +func migration023() Definition { + return Definition{ + Version: 23, + Name: "023_managed_agent_bindings", + Steps: []Step{ + RawDialect("create managed_agent_bindings table", ` + CREATE TABLE IF NOT EXISTS managed_agent_bindings ( + id TEXT PRIMARY KEY, + agent_cuid TEXT NOT NULL, + account TEXT NOT NULL DEFAULT 'default', + claude_agent_id TEXT NOT NULL, + claude_agent_name TEXT NOT NULL DEFAULT '', + claude_agent_model TEXT NOT NULL DEFAULT '', + claude_agent_version INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(agent_cuid, account, claude_agent_id), + UNIQUE(account, claude_agent_id), + FOREIGN KEY(agent_cuid) REFERENCES agents(id) ON DELETE CASCADE + ) + `, ` + CREATE TABLE IF NOT EXISTS managed_agent_bindings ( + id TEXT PRIMARY KEY, + agent_cuid TEXT NOT NULL, + account TEXT NOT NULL DEFAULT 'default', + claude_agent_id TEXT NOT NULL, + claude_agent_name TEXT NOT NULL DEFAULT '', + claude_agent_model TEXT NOT NULL DEFAULT '', + claude_agent_version INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(agent_cuid, account, claude_agent_id), + UNIQUE(account, claude_agent_id), + FOREIGN KEY(agent_cuid) REFERENCES agents(id) ON DELETE CASCADE + ) + `), + }, + } +} diff --git a/internal/store/migrations/registry.go b/internal/store/migrations/registry.go index db4d400..bf2965b 100644 --- a/internal/store/migrations/registry.go +++ b/internal/store/migrations/registry.go @@ -56,5 +56,6 @@ func All() []Definition { migration020(), migration021(), migration022(), + migration023(), } } diff --git a/internal/store/sqlite_test.go b/internal/store/sqlite_test.go index b60ed89..e2df48a 100644 --- a/internal/store/sqlite_test.go +++ b/internal/store/sqlite_test.go @@ -41,12 +41,12 @@ func TestInitDB_FreshDatabase(t *testing.T) { if err := db.QueryRow(`SELECT COUNT(*) FROM schema_migrations`).Scan(&count); err != nil { t.Fatalf("count migrations: %v", err) } - if count != 22 { - t.Fatalf("expected 22 migrations, got %d", count) + if count != 23 { + t.Fatalf("expected 23 migrations, got %d", count) } // Verify all tables exist - tables := []string{"invocations", "invocation_events", "mcp_servers", "oauth_credentials", "oauth_connect_sessions", "approval_rules", "managed_agent_sessions"} + tables := []string{"invocations", "invocation_events", "mcp_servers", "oauth_credentials", "oauth_connect_sessions", "approval_rules", "managed_agent_sessions", "managed_agent_bindings"} for _, table := range tables { var name string if err := db.QueryRow(`SELECT name FROM sqlite_master WHERE type='table' AND name=?`, table).Scan(&name); err != nil { @@ -70,8 +70,8 @@ func TestInitDB_Idempotent(t *testing.T) { if err := db.QueryRow(`SELECT COUNT(*) FROM schema_migrations`).Scan(&count); err != nil { t.Fatalf("count migrations: %v", err) } - if count != 22 { - t.Fatalf("expected 22 migrations after double init, got %d", count) + if count != 23 { + t.Fatalf("expected 23 migrations after double init, got %d", count) } } diff --git a/ui/src/api/AtryumAPI.ts b/ui/src/api/AtryumAPI.ts index ae2364e..1ad1bfe 100644 --- a/ui/src/api/AtryumAPI.ts +++ b/ui/src/api/AtryumAPI.ts @@ -352,6 +352,7 @@ export interface Agent { name: string; description: string; agent_ids: string[]; + claude_managed_agents?: ClaudeManagedAgentBinding[]; enabled: boolean; synced_at: string; /** True when this agent originated from a ValidMind sync and cannot be deleted manually. */ @@ -366,6 +367,8 @@ export interface AgentCreateInput { enabled: boolean; agent_ids?: string[]; charter?: string; + claude_managed_agents?: ClaudeManagedAgentBinding[]; + force_claude_managed_agent_connect?: boolean; } export interface AgentUpdateInput { @@ -374,6 +377,42 @@ export interface AgentUpdateInput { enabled: boolean; agent_ids?: string[]; charter?: string; + claude_managed_agents?: ClaudeManagedAgentBinding[]; + force_claude_managed_agent_connect?: boolean; +} + +export interface ClaudeManagedAgentBinding { + id?: string; + account: string; + claude_agent_id: string; + claude_agent_name?: string; + claude_agent_model?: string; + claude_agent_version?: number; +} + +export interface ClaudeManagedAgent { + id: string; + name: string; + description?: string; + model?: string; + version?: number; + created_at?: string; + updated_at?: string; +} + +export interface ClaudeManagedAgentAccount { + name: string; + workspace: string; +} + +export interface ManagedAgentSession { + session_id: string; + account: string; + agent_id?: string; + description?: string; + last_event_id?: string; + created_at: string; + updated_at: string; } export const agentsApi = { @@ -404,6 +443,39 @@ export const agentsApi = { sync: async (): Promise => { await atryumApi.post('/api/v1/admin/agents/sync'); }, + + managedAgentAccounts: async (): Promise<{ items: ClaudeManagedAgentAccount[] }> => { + const { data } = await atryumApi.get('/api/v1/admin/managed-agents/accounts'); + return data; + }, + + managedAgentSessions: async (): Promise<{ items: ManagedAgentSession[] }> => { + const { data } = await atryumApi.get('/api/v1/admin/managed-agents/sessions'); + return data; + }, + + deleteManagedAgentSession: async (sessionID: string): Promise => { + await atryumApi.delete( + `/api/v1/admin/managed-agents/sessions/${encodeURIComponent(sessionID)}`, + ); + }, + + clearManagedAgentSessions: async (): Promise<{ deleted: number }> => { + const { data } = await atryumApi.delete('/api/v1/admin/managed-agents/sessions'); + return data; + }, + + managedAgents: async ( + account?: string, + q?: string, + ): Promise<{ items: ClaudeManagedAgent[] }> => { + const params = new URLSearchParams(); + if (account) params.set('account', account); + if (q) params.set('q', q); + const suffix = params.toString() ? `?${params.toString()}` : ''; + const { data } = await atryumApi.get(`/api/v1/admin/managed-agents/agents${suffix}`); + return data; + }, }; // ─── Settings ───────────────────────────────────────────────────────────────── diff --git a/ui/src/pages/Agents.tsx b/ui/src/pages/Agents.tsx index 0002ef8..105e7fb 100644 --- a/ui/src/pages/Agents.tsx +++ b/ui/src/pages/Agents.tsx @@ -1,4 +1,4 @@ -import React, { useCallback, useState } from 'react'; +import React, { useCallback, useMemo, useState } from 'react'; import { Alert, AlertDescription, @@ -34,16 +34,28 @@ import { VStack, useDisclosure, } from '@chakra-ui/react'; -import { CreatableSelect } from 'chakra-react-select'; +import { CreatableSelect, Select } from 'chakra-react-select'; import { CpuChipIcon } from '@heroicons/react/24/outline'; +import { useQuery } from 'react-query'; import { ContentPageTitle } from '../components/Layout'; import { useAgents, useCreateAgent, useUpdateAgent, useDeleteAgent } from '../hooks/useAgents'; import { useSettings } from '../hooks/useSettings'; -import type { Agent, AgentCreateInput, AgentUpdateInput } from '../api/AtryumAPI'; +import type { + Agent, + AgentCreateInput, + AgentUpdateInput, + ClaudeManagedAgent, + ClaudeManagedAgentBinding, +} from '../api/AtryumAPI'; import { agentsApi } from '../api/AtryumAPI'; type SelectOption = { value: string; label: string }; +type ManagedAgentOption = { + value: string; + label: string; + binding: ClaudeManagedAgentBinding; +}; const toOptions = (ids: string[]): SelectOption[] => ids.map((id) => ({ value: id, label: id })); const fromOptions = (opts: readonly SelectOption[]): string[] => @@ -61,12 +73,12 @@ const formatDate = (iso: string): string => { }; const errorMessage = (err: unknown, fallback: string): string => { - // Prefer the API response body: { error: { message: "..." } } - if (typeof err === 'object' && err !== null) { - const apiMsg = (err as { response?: { data?: { error?: { message?: unknown } } } }).response - ?.data?.error?.message; - if (typeof apiMsg === 'string' && apiMsg) return apiMsg; - } + // Prefer the API response body: { error: { message: "..." } } + if (typeof err === 'object' && err !== null) { + const apiMsg = (err as { response?: { data?: { error?: { message?: unknown } } } }).response + ?.data?.error?.message; + if (typeof apiMsg === 'string' && apiMsg) return apiMsg; + } if (err instanceof Error) return err.message; if (typeof err === 'object' && err !== null && 'message' in err) { const msg = (err as { message: unknown }).message; @@ -75,6 +87,36 @@ const errorMessage = (err: unknown, fallback: string): string => { return fallback; }; +const statusCode = (err: unknown): number | undefined => { + if (typeof err !== 'object' || err === null || !('response' in err)) return undefined; + return (err as { response?: { status?: number } }).response?.status; +}; + +const bindingKey = (binding: ClaudeManagedAgentBinding): string => + `${binding.account || 'default'}:${binding.claude_agent_id}`; + +const bindingLabel = (binding: ClaudeManagedAgentBinding): string => { + const name = binding.claude_agent_name || binding.claude_agent_id; + return `${name} (${binding.claude_agent_id})`; +}; + +const toManagedAgentBinding = ( + agent: ClaudeManagedAgent, + account: string, +): ClaudeManagedAgentBinding => ({ + account: account || 'default', + claude_agent_id: agent.id, + claude_agent_name: agent.name, + claude_agent_model: agent.model, + claude_agent_version: agent.version, +}); + +const toManagedAgentOption = (binding: ClaudeManagedAgentBinding): ManagedAgentOption => ({ + value: bindingKey(binding), + label: bindingLabel(binding), + binding, +}); + // ─── Create Modal ───────────────────────────────────────────────────────────── type CreateAgentModalProps = { @@ -238,34 +280,83 @@ type EditAgentModalProps = { const EditAgentModal: React.FC = ({ agent, isOpen, onClose }) => { const [name, setName] = useState(agent.name); const [description, setDescription] = useState(agent.description ?? ''); - const [charter, setCharter] = useState(agent.charter ?? ''); - const [enabled, setEnabled] = useState(agent.enabled); - const [agentIDs, setAgentIDs] = useState(agent.agent_ids); - const [statusMsg, setStatusMsg] = useState(null); - - const updateMutation = useUpdateAgent(); - const deleteMutation = useDeleteAgent(); - const { data: agentsData } = useAgents(); - - const handleUpdate = async () => { - if (!name.trim()) { - setStatusMsg({ text: 'Name is required.', isError: true }); - return; + const [charter, setCharter] = useState(agent.charter ?? ''); + const [enabled, setEnabled] = useState(agent.enabled); + const [agentIDs, setAgentIDs] = useState(agent.agent_ids); + const [managedBindings, setManagedBindings] = useState( + agent.claude_managed_agents ?? [], + ); + const [managedAccount, setManagedAccount] = useState( + agent.claude_managed_agents?.[0]?.account ?? 'default', + ); + const [managedSearch, setManagedSearch] = useState(''); + const [forceManagedConnect, setForceManagedConnect] = useState(false); + const [statusMsg, setStatusMsg] = useState(null); + + const updateMutation = useUpdateAgent(); + const deleteMutation = useDeleteAgent(); + const { data: agentsData } = useAgents(); + const accountsQuery = useQuery( + ['claude-managed-agent-accounts'], + () => agentsApi.managedAgentAccounts(), + { enabled: isOpen, refetchOnWindowFocus: false, retry: false }, + ); + const accountItems = accountsQuery.data?.items ?? []; + const selectedAccount = accountItems.some((account) => account.name === managedAccount) + ? managedAccount + : accountItems[0]?.name || managedAccount || 'default'; + const managedAgentsQuery = useQuery( + ['claude-managed-agents', selectedAccount, managedSearch], + () => agentsApi.managedAgents(selectedAccount, managedSearch), + { + enabled: isOpen && !accountsQuery.isError && !accountsQuery.isLoading, + refetchOnWindowFocus: false, + retry: false, + }, + ); + const managedAgentsUnavailable = accountsQuery.isError && statusCode(accountsQuery.error) === 501; + const managedAgentOptions = useMemo(() => { + const byKey = new Map(); + for (const binding of managedBindings) { + byKey.set(bindingKey(binding), toManagedAgentOption(binding)); } - const conflicts = agentIDs.flatMap((id) => { - const owner = agentsData?.items.find((a) => a.cuid !== agent.cuid && a.agent_ids.includes(id)); - return owner ? [`${id} is already in use by "${owner.name}"`] : []; + for (const managedAgent of managedAgentsQuery.data?.items ?? []) { + const binding = toManagedAgentBinding(managedAgent, selectedAccount); + byKey.set(bindingKey(binding), toManagedAgentOption(binding)); + } + return Array.from(byKey.values()); + }, [managedAgentsQuery.data?.items, managedBindings, selectedAccount]); + const selectedManagedAgentOptions = managedBindings.map(toManagedAgentOption); + + const handleUpdate = async () => { + if (!name.trim()) { + setStatusMsg({ text: 'Name is required.', isError: true }); + return; + } + const conflicts = agentIDs.flatMap((id) => { + const owner = agentsData?.items.find((a) => a.cuid !== agent.cuid && a.agent_ids.includes(id)); + return owner ? [`${id} is already in use by "${owner.name}"`] : []; }); if (conflicts.length > 0) { setStatusMsg({ text: 'Agent ID(s) already in use by another agent:', lines: conflicts, isError: true, - }); - return; - } - const input: AgentUpdateInput = { name, description, enabled, agent_ids: agentIDs, charter }; - try { + }); + return; + } + const input: AgentUpdateInput = { + name, + description, + enabled, + agent_ids: agentIDs, + charter, + }; + if (!managedAgentsUnavailable) { + input.claude_managed_agents = managedBindings; + input.force_claude_managed_agent_connect = forceManagedConnect; + } + try { await updateMutation.mutateAsync({ cuid: agent.cuid, input }); setStatusMsg(null); onClose(); @@ -289,7 +380,7 @@ const EditAgentModal: React.FC = ({ agent, isOpen, onClose const isBusy = updateMutation.isLoading || deleteMutation.isLoading; return ( - + {agent.name} @@ -381,6 +472,114 @@ const EditAgentModal: React.FC = ({ agent, isOpen, onClose /> + {!managedAgentsUnavailable && } + + {!managedAgentsUnavailable && ( + + Claude Managed Agents + + Link Anthropic-hosted Claude agents to this Atryum agent. Session + discovery will use these links. + + {accountsQuery.isError ? ( + + + + Claude Managed Agents bridge is not configured. + + + ) : ( + + {accountItems.length > 1 && ( +