diff --git a/.semgrepignore b/.semgrepignore new file mode 100644 index 0000000000..74789df669 --- /dev/null +++ b/.semgrepignore @@ -0,0 +1,4 @@ +# This file intentionally uses text/template (not html/template) because it +# generates LLM prompts (plain text), not HTML for web browsers. +# The XSS rule does not apply here. +internal/agent/prompt/prompt.go diff --git a/internal/agent/agent_params_test.go b/internal/agent/agent_params_test.go new file mode 100644 index 0000000000..51c15e9945 --- /dev/null +++ b/internal/agent/agent_params_test.go @@ -0,0 +1,46 @@ +package agent + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +// AgentParams must decode the subagent_type field so UI renderers can label +// the row with the specific subagent name. +func TestAgentParams_DecodesSubagentType(t *testing.T) { + t.Parallel() + + input := []byte(`{"subagent_type":"code-reviewer","prompt":"review this"}`) + + var params AgentParams + require.NoError(t, json.Unmarshal(input, ¶ms)) + require.Equal(t, "code-reviewer", params.SubagentType) + require.Equal(t, "review this", params.Prompt) +} + +func TestAgentParams_OmitsSubagentTypeWhenAbsent(t *testing.T) { + t.Parallel() + + input := []byte(`{"prompt":"search for things"}`) + + var params AgentParams + require.NoError(t, json.Unmarshal(input, ¶ms)) + require.Empty(t, params.SubagentType) + require.Equal(t, "search for things", params.Prompt) +} + +// AgentParams and AgentDispatchParams must share a wire-compatible shape so +// historical tool-call inputs decode cleanly under both types. +func TestAgentParams_WireCompatibleWithDispatchParams(t *testing.T) { + t.Parallel() + + wire, err := json.Marshal(AgentDispatchParams{SubagentType: "tester", Prompt: "x"}) + require.NoError(t, err) + + var ap AgentParams + require.NoError(t, json.Unmarshal(wire, &ap)) + require.Equal(t, "tester", ap.SubagentType) + require.Equal(t, "x", ap.Prompt) +} diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index 5be58c842e..334359dc2c 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -3,44 +3,145 @@ package agent import ( "context" _ "embed" + "encoding/json" "errors" + "fmt" + "strings" "charm.land/fantasy" "github.com/charmbracelet/crush/internal/agent/prompt" "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/subagents" ) //go:embed templates/agent_tool.md var agentToolDescription string +// AgentParams is the shape consumed by UI tool-call renderers when displaying +// historical agent tool invocations. New tool-call inputs decode with +// AgentDispatchParams; AgentParams stays wire-compatible so older inputs still +// decode cleanly. type AgentParams struct { - Prompt string `json:"prompt" description:"The task for the agent to perform"` + SubagentType string `json:"subagent_type,omitempty"` + Prompt string `json:"prompt" description:"The task for the agent to perform"` +} + +// AgentDispatchParams is the input to the dispatcher agent tool. +type AgentDispatchParams struct { + SubagentType string `json:"subagent_type,omitempty"` + Prompt string `json:"prompt"` } const ( AgentToolName = "agent" ) +// dispatcherTool implements fantasy.AgentTool with a dynamically-built schema. +type dispatcherTool struct { + info fantasy.ToolInfo + dispatch func(ctx context.Context, params AgentDispatchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) + providerOpts fantasy.ProviderOptions +} + +func (d *dispatcherTool) Info() fantasy.ToolInfo { return d.info } +func (d *dispatcherTool) ProviderOptions() fantasy.ProviderOptions { return d.providerOpts } +func (d *dispatcherTool) SetProviderOptions(opts fantasy.ProviderOptions) { d.providerOpts = opts } +func (d *dispatcherTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + var params AgentDispatchParams + if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { + return fantasy.NewTextErrorResponse("invalid parameters: " + err.Error()), nil + } + return d.dispatch(ctx, params, call) +} + +// findSubagentByName returns the active subagent with the given name, or nil +// when none matches. +func findSubagentByName(active []*subagents.Subagent, name string) *subagents.Subagent { + for _, sa := range active { + if sa.Name == name { + return sa + } + } + return nil +} + +// subagentSessionSetup returns a SessionSetup callback that applies the +// subagent's permission mode to the freshly-created sub-session. Returns +// nil when no setup is needed. +func (c *coordinator) subagentSessionSetup(sa *subagents.Subagent) func(sessionID string) { + if sa.PermissionMode != subagents.PermissionModeBypassPermissions { + return nil + } + return func(sessionID string) { + c.permissions.AutoApproveSession(sessionID) + } +} + +// buildAgentDispatchInfo builds the ToolInfo for the agent dispatcher tool with +// a dynamic subagent_type enum derived from the currently active subagents. +func buildAgentDispatchInfo(activeSubagents []*subagents.Subagent) fantasy.ToolInfo { + enumValues := []string{"task"} + for _, sa := range activeSubagents { + enumValues = append(enumValues, sa.Name) + } + + typeDesc := `The type of agent to use. Use "task" for general search and research tasks.` + if len(activeSubagents) > 0 { + lines := make([]string, 0, len(activeSubagents)) + for _, sa := range activeSubagents { + lines = append(lines, fmt.Sprintf("- %s: %s", sa.Name, sa.Description)) + } + typeDesc += "\n\nAvailable specialized agents:\n" + strings.Join(lines, "\n") + } + + return fantasy.ToolInfo{ + Name: AgentToolName, + Description: agentToolDescription, + Parameters: map[string]any{ + "subagent_type": map[string]any{ + "type": "string", + "enum": enumValues, + "description": typeDesc, + }, + "prompt": map[string]any{ + "type": "string", + "description": "The task for the agent to perform", + }, + }, + Required: []string{"prompt"}, + Parallel: true, + } +} + func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) { - agentCfg, ok := c.cfg.Config().Agents[config.AgentTask] + taskCfg, ok := c.cfg.Config().Agents[config.AgentTask] if !ok { return nil, errors.New("task agent not configured") } - prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir())) + coderCfg, ok := c.cfg.Config().Agents[config.AgentCoder] + if !ok { + return nil, errors.New("coder agent not configured") + } + taskPr, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir())) if err != nil { return nil, err } - - agent, err := c.buildAgent(ctx, prompt, agentCfg, true) + taskAgent, err := c.buildAgent(ctx, taskPr, taskCfg, true, subagentModel{}) if err != nil { return nil, err } - return fantasy.NewParallelAgentTool( - AgentToolName, - agentToolDescription, - func(ctx context.Context, params AgentParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + + // The subagent_type enum is a point-in-time snapshot baked into the tool + // schema; a Library reload won't refresh it. Dispatch lookups use the live + // list (activeSubagentsList) so a since-removed name fails cleanly and a + // newly added one still resolves — the enum is advisory only. + info := buildAgentDispatchInfo(c.activeSubagentsList()) + + return &dispatcherTool{ + info: info, + dispatch: func(ctx context.Context, params AgentDispatchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) { if params.Prompt == "" { return fantasy.NewTextErrorResponse("prompt is required"), nil } @@ -49,20 +150,57 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) if sessionID == "" { return fantasy.ToolResponse{}, errors.New("session id missing from context") } - agentMessageID := tools.GetMessageFromContext(ctx) if agentMessageID == "" { return fantasy.ToolResponse{}, errors.New("agent message id missing from context") } + subagentType := params.SubagentType + if subagentType == "" || subagentType == config.AgentTask { + return c.runSubAgent(ctx, subAgentParams{ + Agent: taskAgent, + SessionID: sessionID, + AgentMessageID: agentMessageID, + ToolCallID: call.ID, + Prompt: params.Prompt, + SessionTitle: "New Agent Session", + AgentName: config.AgentTask, + AgentColor: subagents.AutoColor(config.AgentTask), + AgentModel: taskAgent.Model().ModelCfg.Model, + }) + } + + sa := findSubagentByName(c.activeSubagentsList(), subagentType) + if sa == nil { + return fantasy.NewTextErrorResponse(fmt.Sprintf("unknown subagent type: %q", subagentType)), nil + } + + agentCfg := sa.ToConfigAgent(coderCfg) + // Config-driven setup failures (prompt build, model/provider that + // passed discovery but fails at build) are surfaced as tool-error + // responses so the parent agent can report them and continue; a + // bare error would abort the whole turn. + subPr, err := subagentPrompt(sa, c.activeSkills, prompt.WithWorkingDir(c.cfg.WorkingDir())) + if err != nil { + return fantasy.NewTextErrorResponse(fmt.Sprintf("build subagent prompt %q: %v", sa.Name, err)), nil + } + agent, err := c.buildAgent(ctx, subPr, agentCfg, true, subagentModel{Effort: sa.Effort, Model: sa.Model, Provider: sa.Provider}) + if err != nil { + return fantasy.NewTextErrorResponse(fmt.Sprintf("build subagent %q: %v", sa.Name, err)), nil + } + return c.runSubAgent(ctx, subAgentParams{ Agent: agent, SessionID: sessionID, AgentMessageID: agentMessageID, ToolCallID: call.ID, Prompt: params.Prompt, - SessionTitle: "New Agent Session", + SessionTitle: sa.Name + " Agent Session", + SessionSetup: c.subagentSessionSetup(sa), + AgentName: sa.Name, + AgentColor: sa.ResolvedColor(), + AgentModel: agent.Model().ModelCfg.Model, }) }, - ), nil + }, nil } diff --git a/internal/agent/agent_tool_test.go b/internal/agent/agent_tool_test.go new file mode 100644 index 0000000000..017811a9a2 --- /dev/null +++ b/internal/agent/agent_tool_test.go @@ -0,0 +1,339 @@ +package agent + +import ( + "context" + "encoding/json" + "testing" + + "charm.land/catwalk/pkg/catwalk" + "charm.land/fantasy" + "charm.land/fantasy/providers/openaicompat" + "github.com/charmbracelet/crush/internal/agent/tools" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/subagents" + "github.com/stretchr/testify/require" +) + +func TestBuildAgentDispatchInfo_NoSubagents(t *testing.T) { + t.Parallel() + + info := buildAgentDispatchInfo(nil) + + require.Equal(t, "agent", info.Name) + require.True(t, info.Parallel) + require.Contains(t, info.Required, "prompt") + + subagentTypeParam, ok := info.Parameters["subagent_type"] + require.True(t, ok, "Parameters should have a subagent_type key") + + paramMap, ok := subagentTypeParam.(map[string]any) + require.True(t, ok, "subagent_type parameter should be a map[string]any") + + enum, ok := paramMap["enum"] + require.True(t, ok, "subagent_type parameter should have an enum key") + + enumSlice, ok := enum.([]string) + require.True(t, ok, "enum should be a []string") + require.Contains(t, enumSlice, "task") +} + +func TestBuildAgentDispatchInfo_WithSubagents(t *testing.T) { + t.Parallel() + + activeSubagents := []*subagents.Subagent{ + {Name: "code-reviewer", Description: "Reviews code"}, + {Name: "tester", Description: "Writes tests"}, + } + + info := buildAgentDispatchInfo(activeSubagents) + + subagentTypeParam, ok := info.Parameters["subagent_type"] + require.True(t, ok, "Parameters should have a subagent_type key") + + paramMap, ok := subagentTypeParam.(map[string]any) + require.True(t, ok, "subagent_type parameter should be a map[string]any") + + enum, ok := paramMap["enum"] + require.True(t, ok, "subagent_type parameter should have an enum key") + + enumSlice, ok := enum.([]string) + require.True(t, ok, "enum should be a []string") + require.Contains(t, enumSlice, "task") + require.Contains(t, enumSlice, "code-reviewer") + require.Contains(t, enumSlice, "tester") + + // subagent descriptions should appear in the subagent_type parameter description + desc, ok := paramMap["description"] + require.True(t, ok, "subagent_type parameter should have a description key") + descStr, ok := desc.(string) + require.True(t, ok, "description should be a string") + require.Contains(t, descStr, "Reviews code") + require.Contains(t, descStr, "Writes tests") +} + +func TestBuildAgentDispatchInfo_PromptRequired(t *testing.T) { + t.Parallel() + + info := buildAgentDispatchInfo(nil) + + require.Contains(t, info.Required, "prompt") + + // subagent_type is optional — should NOT appear in Required + for _, r := range info.Required { + require.NotEqual(t, "subagent_type", r, "subagent_type should not be required") + } +} + +// dispatcherTool tests — exercise the struct's Run and Info methods without a +// full coordinator. The dispatch closure is injected so no provider setup needed. + +func TestDispatcherTool_Info_ReturnsBuildInfo(t *testing.T) { + t.Parallel() + + info := buildAgentDispatchInfo([]*subagents.Subagent{{Name: "my-agent", Description: "Does stuff"}}) + dt := &dispatcherTool{info: info} + + got := dt.Info() + require.Equal(t, "agent", got.Name) + require.True(t, got.Parallel) +} + +func TestDispatcherTool_Run_ParsesJSONAndCallsDispatch(t *testing.T) { + t.Parallel() + + var capturedParams AgentDispatchParams + dt := &dispatcherTool{ + info: buildAgentDispatchInfo(nil), + dispatch: func(_ context.Context, params AgentDispatchParams, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + capturedParams = params + return fantasy.NewTextResponse("ok"), nil + }, + } + + input, _ := json.Marshal(AgentDispatchParams{SubagentType: "my-agent", Prompt: "do the thing"}) + resp, err := dt.Run(context.Background(), fantasy.ToolCall{Input: string(input)}) + + require.NoError(t, err) + require.False(t, resp.IsError) + require.Equal(t, "my-agent", capturedParams.SubagentType) + require.Equal(t, "do the thing", capturedParams.Prompt) +} + +func TestDispatcherTool_Run_InvalidJSON_ReturnsErrorResponse(t *testing.T) { + t.Parallel() + + dt := &dispatcherTool{ + info: buildAgentDispatchInfo(nil), + dispatch: func(_ context.Context, _ AgentDispatchParams, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + t.Fatal("dispatch should not be called for invalid JSON") + return fantasy.ToolResponse{}, nil + }, + } + + resp, err := dt.Run(context.Background(), fantasy.ToolCall{Input: "not-valid-json{"}) + + require.NoError(t, err) // errors are surfaced as error responses, not Go errors + require.True(t, resp.IsError) +} + +func TestDispatcherTool_Run_EmptySubagentType_RoutesToTask(t *testing.T) { + t.Parallel() + + var capturedParams AgentDispatchParams + dt := &dispatcherTool{ + info: buildAgentDispatchInfo(nil), + dispatch: func(_ context.Context, params AgentDispatchParams, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + capturedParams = params + return fantasy.NewTextResponse("ok"), nil + }, + } + + input, _ := json.Marshal(AgentDispatchParams{Prompt: "search for something"}) + _, err := dt.Run(context.Background(), fantasy.ToolCall{Input: string(input)}) + + require.NoError(t, err) + require.Empty(t, capturedParams.SubagentType) // dispatch receives params as-is; routing is in the closure +} + +func TestDispatcherTool_ProviderOptions_RoundTrip(t *testing.T) { + t.Parallel() + + dt := &dispatcherTool{info: buildAgentDispatchInfo(nil)} + require.Nil(t, dt.ProviderOptions()) + + opts := fantasy.ProviderOptions{} + dt.SetProviderOptions(opts) + require.NotNil(t, dt.ProviderOptions()) +} + +func TestFindSubagentByName(t *testing.T) { + t.Parallel() + + active := []*subagents.Subagent{ + {Name: "alpha"}, + {Name: "beta"}, + } + + require.NotNil(t, findSubagentByName(active, "alpha")) + require.Equal(t, "alpha", findSubagentByName(active, "alpha").Name) + require.Equal(t, "beta", findSubagentByName(active, "beta").Name) + require.Nil(t, findSubagentByName(active, "missing")) + require.Nil(t, findSubagentByName(active, "")) + require.Nil(t, findSubagentByName(nil, "alpha")) +} + +// TestDispatcherTool_Run_UnknownSubagent_ReturnsErrorResponse exercises the +// dispatcher routing for a subagent_type not in the active list. The closure +// here mirrors the lookup performed by (*coordinator).agentTool. +func TestDispatcherTool_Run_UnknownSubagent_ReturnsErrorResponse(t *testing.T) { + t.Parallel() + + active := []*subagents.Subagent{ + {Name: "code-reviewer", Description: "ok"}, + } + + dt := &dispatcherTool{ + info: buildAgentDispatchInfo(active), + dispatch: func(_ context.Context, params AgentDispatchParams, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + sa := findSubagentByName(active, params.SubagentType) + if sa == nil { + return fantasy.NewTextErrorResponse("unknown subagent type: \"" + params.SubagentType + "\""), nil + } + return fantasy.NewTextResponse("would have run " + sa.Name), nil + }, + } + + input, _ := json.Marshal(AgentDispatchParams{SubagentType: "imaginary", Prompt: "do thing"}) + resp, err := dt.Run(context.Background(), fantasy.ToolCall{Input: string(input)}) + + require.NoError(t, err) + require.True(t, resp.IsError) +} + +// recordingPermissions stubs permission.Service to capture +// AutoApproveSession calls for subagent dispatch tests. All other methods +// are no-ops or return zero values. +type recordingPermissions struct { + permission.Service + autoApproved []string +} + +func (r *recordingPermissions) AutoApproveSession(sessionID string) { + r.autoApproved = append(r.autoApproved, sessionID) +} + +func TestSubagentSessionSetup(t *testing.T) { + t.Parallel() + + t.Run("nil_when_no_bypass", func(t *testing.T) { + t.Parallel() + c := &coordinator{} + require.Nil(t, c.subagentSessionSetup(&subagents.Subagent{Name: "a"})) + require.Nil(t, c.subagentSessionSetup(&subagents.Subagent{Name: "a", PermissionMode: subagents.PermissionModeDefault})) + }) + + t.Run("bypass_calls_auto_approve", func(t *testing.T) { + t.Parallel() + rec := &recordingPermissions{} + c := &coordinator{permissions: rec} + sa := &subagents.Subagent{Name: "a", PermissionMode: subagents.PermissionModeBypassPermissions} + + setup := c.subagentSessionSetup(sa) + require.NotNil(t, setup) + + setup("session-123") + require.Equal(t, []string{"session-123"}, rec.autoApproved) + }) +} + +// TestAgentTool_SubagentBuildFailure_SurfacedAsToolError verifies that when a +// named subagent fails to build (because its model: names a model no provider +// offers), the dispatcher returns a ToolResponse with IsError==true and a nil +// Go error. A nil Go error is critical: fantasy treats a non-nil error as a +// hard abort of the whole agent turn, whereas an error response lets the +// parent model see the failure and continue. +func TestAgentTool_SubagentBuildFailure_SurfacedAsToolError(t *testing.T) { + t.Parallel() + + env := testEnv(t) + + // Build a minimal offline config with one provider and one model, mirroring + // agenttest.NewCoordinator so no network call is needed. + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + + const ( + providerID = "test-openai-compat" + modelID = "test-model" + ) + cfg.Config().Providers.Set(providerID, config.ProviderConfig{ + ID: providerID, + Name: "Test", + Type: openaicompat.Name, + BaseURL: "http://127.0.0.1:0/v1", + APIKey: "test", + Models: []catwalk.Model{{ID: modelID, DefaultMaxTokens: 4096}}, + }) + selected := config.SelectedModel{Provider: providerID, Model: modelID} + cfg.Config().Models[config.SelectedModelTypeLarge] = selected + cfg.Config().Models[config.SelectedModelTypeSmall] = selected + cfg.SetupAgents() + + // Clear AllowedTools on both agents so buildTools stays cheap and offline. + for _, agentID := range []string{config.AgentCoder, config.AgentTask} { + a := cfg.Config().Agents[agentID] + a.AllowedTools = nil + cfg.Config().Agents[agentID] = a + } + + c, err := NewCoordinator( + t.Context(), + cfg, + env.sessions, + env.messages, + permission.NewPermissionService(env.workingDir, true, nil), + nil, // history + nil, // filetracker + nil, // lspManager + nil, // notify + nil, // runComplete + nil, // skillsMgr + nil, // subagentsMgr + nil, // runtime + ) + require.NoError(t, err) + + // Type-assert to *coordinator so we can access unexported fields and methods. + coord := c.(*coordinator) + + // Inject a broken subagent whose model is not offered by any provider. + // activeSubagentsList falls back to activeSubagents when subagentsMgr is nil. + coord.activeSubagents = []*subagents.Subagent{ + {Name: "broken", Description: "intentionally broken", Model: "no-such-model"}, + } + + // Retrieve the real dispatcher tool built by agentTool. + tool, err := coord.agentTool(t.Context()) + require.NoError(t, err) + + dt := tool.(*dispatcherTool) + + // Inject session and message IDs into context; the dispatch closure returns + // a hard error when either is absent, which is a different code path. + ctx := context.WithValue(t.Context(), tools.SessionIDContextKey, "sess-1") + ctx = context.WithValue(ctx, tools.MessageIDContextKey, "msg-1") + + input, err := json.Marshal(AgentDispatchParams{SubagentType: "broken", Prompt: "do it"}) + require.NoError(t, err) + + resp, err := dt.Run(ctx, fantasy.ToolCall{ID: "call-1", Input: string(input)}) + + // The turn must not abort: fantasy treats a non-nil error as critical. + require.NoError(t, err) + // The build failure must be surfaced as a tool-error response so the + // parent model can report it and continue. + require.True(t, resp.IsError) + // The subagent name must appear in the error message. + require.Contains(t, resp.Content, "broken") +} diff --git a/internal/agent/agenttest/coordinator.go b/internal/agent/agenttest/coordinator.go index fdacb7e129..c723661be6 100644 --- a/internal/agent/agenttest/coordinator.go +++ b/internal/agent/agenttest/coordinator.go @@ -76,5 +76,7 @@ func NewCoordinator( nil, nil, nil, + nil, + nil, ) } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 86ca09e3bf..a30fe06eaf 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -15,6 +15,7 @@ import ( "path/filepath" "slices" "strings" + "sync" "charm.land/catwalk/pkg/catwalk" "charm.land/fantasy" @@ -35,6 +36,7 @@ import ( "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/skills" + "github.com/charmbracelet/crush/internal/subagents" "golang.org/x/sync/errgroup" "charm.land/fantasy/providers/anthropic" @@ -121,6 +123,20 @@ type coordinator struct { activeSkills []*skills.Skill // Post-filter: active skills only. skillTracker *skills.Tracker + // Subagents discovery. subagentsMgr is the live source of truth (its + // snapshot changes when the Library reloads); activeSubagents is a + // fallback snapshot used only when no manager was supplied (e.g. tests). + subagentsMgr *subagents.Manager + activeSubagents []*subagents.Subagent + + // runtime tracks which sub-agents are currently running. + runtime *subagents.Runtime + + // subagentModelCache memoizes resolveModelByID results within a config + // generation. Cleared by UpdateModels to avoid reusing stale clients. + subagentModelCache map[subagentModelKey]Model + subagentModelCacheMu sync.RWMutex + readyWg errgroup.Group } @@ -136,6 +152,8 @@ func NewCoordinator( notify pubsub.Publisher[notify.Notification], runComplete pubsub.Publisher[notify.RunComplete], skillsMgr *skills.Manager, + subagentsMgr *subagents.Manager, + runtime *subagents.Runtime, ) (Coordinator, error) { // Skills are pre-discovered by the caller (see app.New / // backend.CreateWorkspace) and passed in via the manager. If no @@ -151,20 +169,27 @@ func NewCoordinator( skillTracker := skills.NewTracker(activeSkills) c := &coordinator{ - cfg: cfg, - sessions: sessions, - messages: messages, - permissions: permissions, - history: history, - filetracker: filetracker, - lspManager: lspManager, - notify: notify, - runComplete: runComplete, - agents: make(map[string]SessionAgent), - allSkills: allSkills, - activeSkills: activeSkills, - skillTracker: skillTracker, - } + cfg: cfg, + sessions: sessions, + messages: messages, + permissions: permissions, + history: history, + filetracker: filetracker, + lspManager: lspManager, + notify: notify, + runComplete: runComplete, + agents: make(map[string]SessionAgent), + allSkills: allSkills, + activeSkills: activeSkills, + skillTracker: skillTracker, + subagentModelCache: make(map[subagentModelKey]Model), + } + + c.subagentsMgr = subagentsMgr + if subagentsMgr != nil { + c.activeSubagents = subagentsMgr.ActiveSubagents() + } + c.runtime = runtime agentCfg, ok := cfg.Config().Agents[config.AgentCoder] if !ok { @@ -177,7 +202,7 @@ func NewCoordinator( return nil, err } - agent, err := c.buildAgent(ctx, prompt, agentCfg, false) + agent, err := c.buildAgent(ctx, prompt, agentCfg, false, subagentModel{}) if err != nil { return nil, err } @@ -526,17 +551,166 @@ func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderO return modelOptions, temp, topP, topK, freqPenalty, presPenalty } -func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) { - large, small, err := c.buildAgentModels(ctx, isSubAgent) +// activeSubagentsList returns the current active subagents. It reads the live +// manager snapshot when available (so Library reloads are reflected without a +// restart) and falls back to the construction-time snapshot otherwise. +func (c *coordinator) activeSubagentsList() []*subagents.Subagent { + if c.subagentsMgr != nil { + return c.subagentsMgr.ActiveSubagents() + } + return c.activeSubagents +} + +// findModelProvider returns the provider config and catwalk model for the +// provider that offers modelID. When providerOverride is non-empty only that +// provider is searched. ok is false when no matching provider/model is found. +func (c *coordinator) findModelProvider(modelID, providerOverride string) (config.ProviderConfig, catwalk.Model, bool) { + if providerOverride != "" { + p, ok := c.cfg.Config().Providers.Get(providerOverride) + if !ok { + return config.ProviderConfig{}, catwalk.Model{}, false + } + m, ok := findCatwalkModel(p, modelID) + if !ok { + return config.ProviderConfig{}, catwalk.Model{}, false + } + return p, m, true + } + for p := range c.cfg.Config().Providers.Seq() { + if m, ok := findCatwalkModel(p, modelID); ok { + return p, m, true + } + } + return config.ProviderConfig{}, catwalk.Model{}, false +} + +// findCatwalkModel returns the catwalk model with the given id from a provider. +func findCatwalkModel(providerCfg config.ProviderConfig, modelID string) (catwalk.Model, bool) { + for _, m := range providerCfg.Models { + if m.ID == modelID { + return m, true + } + } + return catwalk.Model{}, false +} + +// buildModel constructs a Model from an already-resolved provider, selected +// model, and catwalk model. Shared by buildNamedModel and resolveModelByID. +func (c *coordinator) buildModel(ctx context.Context, providerCfg config.ProviderConfig, selModel config.SelectedModel, catwalkModel catwalk.Model, isSubAgent bool) (Model, error) { + provider, err := c.buildProvider(providerCfg, selModel, isSubAgent) + if err != nil { + return Model{}, err + } + modelID := selModel.Model + if providerCfg.ID == openrouter.Name && isExactoSupported(modelID) { + modelID += ":exacto" + } + lm, err := provider.LanguageModel(ctx, modelID) + if err != nil { + return Model{}, err + } + return Model{Model: lm, CatwalkCfg: catwalkModel, ModelCfg: selModel, FlatRate: providerCfg.FlatRate}, nil +} + +// buildNamedModel builds the large or small model selected in config. Errors +// distinguish the two so callers and tests can tell which model failed. +func (c *coordinator) buildNamedModel(ctx context.Context, modelType config.SelectedModelType, isSubAgent bool) (Model, error) { + isSmall := modelType == config.SelectedModelTypeSmall + selModel, ok := c.cfg.Config().Models[modelType] + if !ok { + if isSmall { + return Model{}, errSmallModelNotSelected + } + return Model{}, errLargeModelNotSelected + } + providerCfg, ok := c.cfg.Config().Providers.Get(selModel.Provider) + if !ok { + if isSmall { + return Model{}, errSmallModelProviderNotConfigured + } + return Model{}, errLargeModelProviderNotConfigured + } + catwalkModel, ok := findCatwalkModel(providerCfg, selModel.Model) + if !ok { + if isSmall { + return Model{}, errSmallModelNotFound + } + return Model{}, errLargeModelNotFound + } + return c.buildModel(ctx, providerCfg, selModel, catwalkModel, isSubAgent) +} + +// resolveModelByID finds the provider that offers modelID and builds a Model +// for it. It lets a subagent run on the specific model named in its `model:` +// frontmatter (validated at discovery via Config.IsKnownModel). When +// providerOverride is non-empty only that provider is searched. +// +// Results are memoized in subagentModelCache for the lifetime of the current +// config generation. UpdateModels clears the cache on config reload. +func (c *coordinator) resolveModelByID(ctx context.Context, modelID, providerOverride string, isSubAgent bool) (Model, error) { + key := subagentModelKey{modelID: modelID, provider: providerOverride, isSubAgent: isSubAgent} + + c.subagentModelCacheMu.RLock() + if m, ok := c.subagentModelCache[key]; ok { + c.subagentModelCacheMu.RUnlock() + return m, nil + } + c.subagentModelCacheMu.RUnlock() + + providerCfg, catwalkModel, ok := c.findModelProvider(modelID, providerOverride) + if !ok { + return Model{}, fmt.Errorf("model %q not found in any configured provider", modelID) + } + selModel := config.SelectedModel{Provider: providerCfg.ID, Model: modelID} + m, err := c.buildModel(ctx, providerCfg, selModel, catwalkModel, isSubAgent) + if err != nil { + return Model{}, err + } + + if c.subagentModelCache != nil { + c.subagentModelCacheMu.Lock() + c.subagentModelCache[key] = m + c.subagentModelCacheMu.Unlock() + } + return m, nil +} + +// buildAgent constructs a SessionAgent. sm carries the model-selection fields +// from subagent frontmatter (zero value for the coder/task agents): sm.Model is +// "" or "large" (global large), "small" (global small), or a specific model id +// resolved via resolveModelByID. sm.Effort is applied to the resolved primary, +// which is also the only large/specific model built — small always backs +// titles/summaries, so it is built unconditionally. +func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool, sm subagentModel) (SessionAgent, error) { + small, err := c.buildNamedModel(ctx, config.SelectedModelTypeSmall, true) if err != nil { return nil, err } - largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider) + var primary Model + switch sm.Model { + case "small": + primary = small + case "", "large": + primary, err = c.buildNamedModel(ctx, config.SelectedModelTypeLarge, isSubAgent) + default: + primary, err = c.resolveModelByID(ctx, sm.Model, sm.Provider, isSubAgent) + } + if err != nil { + return nil, err + } + + if subagents.EffortIgnored(sm.Effort, primary.CatwalkCfg) { + slog.Warn("Subagent effort ignored: model does not support reasoning", + "model", primary.ModelCfg.Model, "effort", sm.Effort) + } + primary.ModelCfg = subagents.ApplyEffortToModel(sm.Effort, primary.ModelCfg, primary.CatwalkCfg) + + primaryProviderCfg, _ := c.cfg.Config().Providers.Get(primary.ModelCfg.Provider) result := NewSessionAgent(SessionAgentOptions{ - LargeModel: large, + LargeModel: primary, SmallModel: small, - SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix, + SystemPromptPrefix: primaryProviderCfg.SystemPromptPrefix, SystemPrompt: "", IsSubAgent: isSubAgent, DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize, @@ -549,7 +723,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age }) c.readyWg.Go(func() error { - systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg) + systemPrompt, err := prompt.Build(ctx, primary.Model.Provider(), primary.Model.Model(), c.cfg) if err != nil { return err } @@ -558,7 +732,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age }) c.readyWg.Go(func() error { - tools, err := c.buildTools(ctx, agent, isSubAgent) + tools, err := c.buildTools(ctx, agent, isSubAgent, primary.CatwalkCfg.ID) if err != nil { return err } @@ -569,9 +743,22 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age return result, nil } -func (c *coordinator) buildTools(ctx context.Context, agent config.Agent, isSubAgent bool) ([]fantasy.AgentTool, error) { +// shouldExposeDispatcher reports whether the dispatcher agent tool should be +// included for an agent. Sub-agents never receive it — that prevents recursive +// delegation regardless of what their AllowedTools list contains. +func shouldExposeDispatcher(allowed []string, isSubAgent bool) bool { + if isSubAgent { + return false + } + return slices.Contains(allowed, AgentToolName) +} + +// buildTools assembles the agent's tool set. modelID is the catwalk id of the +// model the agent actually runs on (the resolved primary), used for +// model-specific tool guidance such as the bash tool description. +func (c *coordinator) buildTools(ctx context.Context, agent config.Agent, isSubAgent bool, modelID string) ([]fantasy.AgentTool, error) { var allTools []fantasy.AgentTool - if slices.Contains(agent.AllowedTools, AgentToolName) { + if shouldExposeDispatcher(agent.AllowedTools, isSubAgent) { agentTool, err := c.agentTool(ctx) if err != nil { return nil, err @@ -587,14 +774,6 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent, isSubA allTools = append(allTools, agenticFetchTool) } - // Get the model name for the agent - modelID := "" - if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok { - if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil { - modelID = model.ID - } - } - logFile := filepath.Join(c.cfg.Config().Options.DataDirectory, "logs", "crush.log") // Build hook runner if PreToolUse hooks are configured. @@ -682,88 +861,15 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent, isSubA // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) { - largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge] - if !ok { - return Model{}, Model{}, errLargeModelNotSelected - } - smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall] - if !ok { - return Model{}, Model{}, errSmallModelNotSelected - } - - largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider) - if !ok { - return Model{}, Model{}, errLargeModelProviderNotConfigured - } - - largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent) - if err != nil { - return Model{}, Model{}, err - } - - smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider) - if !ok { - return Model{}, Model{}, errSmallModelProviderNotConfigured - } - - smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true) - if err != nil { - return Model{}, Model{}, err - } - - var largeCatwalkModel *catwalk.Model - var smallCatwalkModel *catwalk.Model - - for _, m := range largeProviderCfg.Models { - if m.ID == largeModelCfg.Model { - largeCatwalkModel = &m - } - } - for _, m := range smallProviderCfg.Models { - if m.ID == smallModelCfg.Model { - smallCatwalkModel = &m - } - } - - if largeCatwalkModel == nil { - return Model{}, Model{}, errLargeModelNotFound - } - - if smallCatwalkModel == nil { - return Model{}, Model{}, errSmallModelNotFound - } - - largeModelID := largeModelCfg.Model - smallModelID := smallModelCfg.Model - - if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) { - largeModelID += ":exacto" - } - - if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) { - smallModelID += ":exacto" - } - - largeModel, err := largeProvider.LanguageModel(ctx, largeModelID) + large, err := c.buildNamedModel(ctx, config.SelectedModelTypeLarge, isSubAgent) if err != nil { return Model{}, Model{}, err } - smallModel, err := smallProvider.LanguageModel(ctx, smallModelID) + small, err := c.buildNamedModel(ctx, config.SelectedModelTypeSmall, true) if err != nil { return Model{}, Model{}, err } - - return Model{ - Model: largeModel, - CatwalkCfg: *largeCatwalkModel, - ModelCfg: largeModelCfg, - FlatRate: largeProviderCfg.FlatRate, - }, Model{ - Model: smallModel, - CatwalkCfg: *smallCatwalkModel, - ModelCfg: smallModelCfg, - FlatRate: smallProviderCfg.FlatRate, - }, nil + return large, small, nil } func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) { @@ -1079,6 +1185,12 @@ func (c *coordinator) Model() Model { } func (c *coordinator) UpdateModels(ctx context.Context) error { + // Clear the subagent model cache so that any stale LanguageModel instances + // (built against the old config) are not reused after a config reload. + c.subagentModelCacheMu.Lock() + c.subagentModelCache = make(map[subagentModelKey]Model) + c.subagentModelCacheMu.Unlock() + // build the models again so we make sure we get the latest config large, small, err := c.buildAgentModels(ctx, false) if err != nil { @@ -1091,7 +1203,7 @@ func (c *coordinator) UpdateModels(ctx context.Context) error { return errCoderAgentNotConfigured } - tools, err := c.buildTools(ctx, agentCfg, false) + tools, err := c.buildTools(ctx, agentCfg, false, large.CatwalkCfg.ID) if err != nil { return err } @@ -1137,11 +1249,15 @@ func (c *coordinator) refreshTokenIfExpired(ctx context.Context, providerCfg con // attempts to refresh credentials and re-runs fn once. Returns the // final error: from the retry if a retry was attempted, otherwise from // the original run. Callers that need to notify the user on persistent -// failure should check isUnauthorized on the returned error. -func (c *coordinator) runWithUnauthorizedRetry(ctx context.Context, providerCfg config.ProviderConfig, fn func() error) error { +// failure should check isUnauthorized on the returned error. The optional +// onRetry callback is invoked once just before the retry attempt. +func (c *coordinator) runWithUnauthorizedRetry(ctx context.Context, providerCfg config.ProviderConfig, fn func() error, onRetry ...func()) error { err := fn() if err != nil && c.isUnauthorized(err) { if retryErr := c.retryAfterUnauthorized(ctx, providerCfg); retryErr == nil { + for _, cb := range onRetry { + cb() + } return fn() } } @@ -1195,6 +1311,21 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con return nil } +// subagentModel carries the model-selection fields from subagent frontmatter. +// The zero value selects the global large model. +type subagentModel struct { + Effort string + Model string + Provider string +} + +// subagentModelKey is the cache key for resolveModelByID results. +type subagentModelKey struct { + modelID string + provider string + isSubAgent bool +} + // subAgentParams holds the parameters for running a sub-agent. type subAgentParams struct { Agent SessionAgent @@ -1203,6 +1334,9 @@ type subAgentParams struct { ToolCallID string Prompt string SessionTitle string + AgentName string + AgentColor string + AgentModel string // SessionSetup is an optional callback invoked after session creation // but before agent execution, for custom session configuration. SessionSetup func(sessionID string) @@ -1224,6 +1358,12 @@ func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (f params.SessionSetup(session.ID) } + // Register with the runtime tracker and finish on return. finalStatus is + // captured by the deferred call and updated below based on the outcome. + c.runtime.Register(params.SessionID, session.ID, params.AgentName, params.AgentColor, params.AgentModel) + finalStatus := subagents.StatusCompleted + defer func() { c.runtime.Finish(session.ID, finalStatus) }() + // Get model configuration model := params.Agent.Model() maxTokens := model.CatwalkCfg.DefaultMaxTokens @@ -1256,6 +1396,8 @@ func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (f var runErr error result, runErr = run() return runErr + }, func() { + c.runtime.SetStatus(session.ID, "retrying") }) // Notify only if still unauthorized after retry. if err != nil && c.isUnauthorized(err) && c.notify != nil && model.ModelCfg.Provider == hyper.Name { @@ -1265,6 +1407,11 @@ func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (f }) } if err != nil { + if errors.Is(err, context.Canceled) { + finalStatus = subagents.StatusCancelled + return fantasy.NewTextErrorResponse("Subagent cancelled by user"), nil + } + finalStatus = subagents.StatusFailed return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to generate response: %s", err)), nil } diff --git a/internal/agent/coordinator_test.go b/internal/agent/coordinator_test.go index c522ef5de1..619d0e7cdf 100644 --- a/internal/agent/coordinator_test.go +++ b/internal/agent/coordinator_test.go @@ -10,6 +10,7 @@ import ( "charm.land/fantasy/providers/anthropic" "charm.land/fantasy/providers/bedrock" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/subagents" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -52,8 +53,9 @@ func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCf require.NoError(t, err) cfg.Config().Providers.Set(providerID, providerCfg) return &coordinator{ - cfg: cfg, - sessions: env.sessions, + cfg: cfg, + sessions: env.sessions, + subagentModelCache: make(map[subagentModelKey]Model), } } @@ -390,6 +392,360 @@ func TestUpdateParentSessionCost(t *testing.T) { }) } +// TestCoordinator_ActiveSubagentsFromManager verifies that coordinator.activeSubagents +// is populated from a Manager's ActiveSubagents slice. This test will fail to compile +// until the activeSubagents field exists on the coordinator struct. +func TestCoordinator_ActiveSubagentsFromManager(t *testing.T) { + t.Parallel() + + active := []*subagents.Subagent{ + {Name: "test-agent", Description: "A test agent"}, + {Name: "another-agent", Description: "Another test agent"}, + } + mgr := subagents.NewManager(active, active, nil) + + // Construct the coordinator directly (mirrors newTestCoordinator style). + // This fails to compile if activeSubagents field does not exist on coordinator. + c := &coordinator{ + activeSubagents: mgr.ActiveSubagents(), + } + + require.Len(t, c.activeSubagents, 2) + require.Equal(t, "test-agent", c.activeSubagents[0].Name) + require.Equal(t, "another-agent", c.activeSubagents[1].Name) +} + +// TestCoordinator_ActiveSubagentsNilManager verifies that coordinator.activeSubagents +// is nil (zero value) when no Manager is wired in. This test will fail to compile +// until the activeSubagents field exists on the coordinator struct. +func TestCoordinator_ActiveSubagentsNilManager(t *testing.T) { + t.Parallel() + + // Construct coordinator without setting activeSubagents — mirrors the + // nil-manager branch of NewCoordinator (no subagents wired). + c := &coordinator{} + + require.Nil(t, c.activeSubagents) +} + +// TestCoordinator_ActiveSubagentsFieldType verifies that the activeSubagents field +// has type []*subagents.Subagent. A direct struct literal assignment is used so the +// test fails to compile with a type mismatch if the field type is wrong. +func TestCoordinator_ActiveSubagentsFieldType(t *testing.T) { + t.Parallel() + + // This literal fails to compile if the field type is not []*subagents.Subagent. + c := &coordinator{ + activeSubagents: []*subagents.Subagent{ + {Name: "compile-check"}, + }, + } + + require.Len(t, c.activeSubagents, 1) + assert.Equal(t, "compile-check", c.activeSubagents[0].Name) +} + +// TestRunSubAgent_RegistersAndFinishesRuntime verifies that runSubAgent calls +// Register on the coordinator's Runtime after session creation and Finish when +// it returns, propagating AgentName, AgentColor and AgentModel from params. +func TestRunSubAgent_RegistersAndFinishesRuntime(t *testing.T) { + t.Parallel() + + const providerID = "test-provider" + providerCfg := config.ProviderConfig{ID: providerID} + + env := testEnv(t) + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + cfg.Config().Providers.Set(providerID, providerCfg) + + rt := subagents.NewRuntime() + t.Cleanup(rt.Shutdown) + + // Channel to capture the session ID used during the agent run so we can + // assert that List sees the entry while runSubAgent is in-flight. + type snapshot struct { + entries []subagents.RunningEntry + } + snapCh := make(chan snapshot, 1) + + parentSession, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) { + // Capture a snapshot of the Runtime state while the sub-agent is running. + snapCh <- snapshot{entries: rt.List(parentSession.ID)} + return agentResultWithText("done"), nil + }) + + coord := &coordinator{ + cfg: cfg, + sessions: env.sessions, + runtime: rt, + } + + _, err = coord.runSubAgent(t.Context(), subAgentParams{ + Agent: agent, + SessionID: parentSession.ID, + AgentMessageID: "msg-1", + ToolCallID: "call-1", + Prompt: "do something", + SessionTitle: "Runtime Test", + AgentName: "my-agent", + AgentColor: "blue", + AgentModel: "claude-test", + }) + require.NoError(t, err) + + // Verify the in-flight snapshot captured exactly one entry with correct fields. + select { + case snap := <-snapCh: + require.Len(t, snap.entries, 1, "Runtime must have one entry while runSubAgent is in-flight") + e := snap.entries[0] + require.Equal(t, parentSession.ID, e.ParentSessionID) + require.Equal(t, "my-agent", e.Name) + require.Equal(t, "blue", e.Color) + require.Equal(t, "claude-test", e.Model) + require.Equal(t, subagents.StatusRunning, e.Status) + require.False(t, e.StartedAt.IsZero()) + default: + t.Fatal("agent run function was never called") + } + + // After runSubAgent returns, the entry must be gone. + after := rt.List(parentSession.ID) + require.Empty(t, after, "Runtime must have no entries after runSubAgent returns") +} + +// TestResolveModelByID_UnknownErrors verifies resolveModelByID errors when no +// configured provider offers the requested model id. +func TestResolveModelByID_UnknownErrors(t *testing.T) { + t.Parallel() + + env := testEnv(t) + coord := newTestCoordinator(t, env, "p", config.ProviderConfig{ID: "p"}) + + _, err := coord.resolveModelByID(t.Context(), "no-such-model", "", true) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") +} + +// TestResolveModelByID_WithProviderOverride verifies that when a providerOverride +// is supplied, resolveModelByID restricts lookup to that provider. +func TestResolveModelByID_WithProviderOverride(t *testing.T) { + t.Parallel() + + env := testEnv(t) + providerCfg := config.ProviderConfig{ + ID: "test-provider", + Models: []catwalk.Model{{ID: "model-a"}}, + } + coord := newTestCoordinator(t, env, "test-provider", providerCfg) + + t.Run("unknown_provider_override_errors", func(t *testing.T) { + t.Parallel() + + _, err := coord.resolveModelByID(t.Context(), "model-a", "nonexistent-provider", true) + require.Error(t, err) + }) +} + +// TestFindModelProvider verifies the pure provider/catwalk lookup used by +// resolveModelByID to back a subagent's specific `model:` id. +func TestFindModelProvider(t *testing.T) { + t.Parallel() + + env := testEnv(t) + providerCfg := config.ProviderConfig{ + ID: "test-provider", + Models: []catwalk.Model{{ID: "model-a"}, {ID: "model-b"}}, + } + coord := newTestCoordinator(t, env, "test-provider", providerCfg) + + t.Run("found", func(t *testing.T) { + t.Parallel() + + pc, m, ok := coord.findModelProvider("model-b", "") + require.True(t, ok) + require.Equal(t, "test-provider", pc.ID) + require.Equal(t, "model-b", m.ID) + }) + + t.Run("unknown", func(t *testing.T) { + t.Parallel() + + _, _, ok := coord.findModelProvider("no-such-model", "") + require.False(t, ok) + }) + + t.Run("provider_override_match", func(t *testing.T) { + t.Parallel() + + pc, m, ok := coord.findModelProvider("model-a", "test-provider") + require.True(t, ok) + require.Equal(t, "test-provider", pc.ID) + require.Equal(t, "model-a", m.ID) + }) + + t.Run("provider_override_no_match_wrong_provider", func(t *testing.T) { + t.Parallel() + + _, _, ok := coord.findModelProvider("model-a", "other-provider") + require.False(t, ok) + }) + + t.Run("provider_override_no_match_unknown_model", func(t *testing.T) { + t.Parallel() + + _, _, ok := coord.findModelProvider("no-such-model", "test-provider") + require.False(t, ok) + }) +} + +// TestFindModelProvider_TwoProvidersSameModelID verifies behavior when two +// providers expose the same model ID. +func TestFindModelProvider_TwoProvidersSameModelID(t *testing.T) { + t.Parallel() + + env := testEnv(t) + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + + cfg.Config().Providers.Set("provider-a", config.ProviderConfig{ + ID: "provider-a", + Models: []catwalk.Model{{ID: "shared-model"}}, + }) + cfg.Config().Providers.Set("provider-b", config.ProviderConfig{ + ID: "provider-b", + Models: []catwalk.Model{{ID: "shared-model"}}, + }) + + coord := &coordinator{cfg: cfg, sessions: env.sessions} + + t.Run("no_override_returns_one_result_no_panic", func(t *testing.T) { + t.Parallel() + + pc, m, ok := coord.findModelProvider("shared-model", "") + require.True(t, ok, "must find shared-model in at least one provider") + require.Equal(t, "shared-model", m.ID) + require.NotEmpty(t, pc.ID) + }) + + t.Run("override_selects_specific_provider", func(t *testing.T) { + t.Parallel() + + pc, m, ok := coord.findModelProvider("shared-model", "provider-b") + require.True(t, ok) + require.Equal(t, "provider-b", pc.ID) + require.Equal(t, "shared-model", m.ID) + }) +} + +// TestBuildAgent_SubagentModel verifies that buildAgent accepts a subagentModel +// struct and routes model selection correctly. +func TestBuildAgent_SubagentModel(t *testing.T) { + t.Parallel() + + t.Run("zero_value_uses_large_model", func(t *testing.T) { + t.Parallel() + + env := testEnv(t) + coord := &coordinator{ + cfg: config.NewTestStoreWithWorkingDir(&config.Config{}, env.workingDir), + sessions: env.sessions, + } + + agentCfg := config.Agent{ + ID: "test", + Name: "test", + AllowedTools: []string{}, + } + + // Zero-value subagentModel must be accepted without panicking. With no + // models configured, buildNamedModel fails before any prompt is needed, + // verifying the struct parameter is wired into model-selection logic. + _, err := coord.buildAgent(t.Context(), nil, agentCfg, true, subagentModel{}) + require.Error(t, err) + require.Contains(t, err.Error(), "model") + }) +} + +// TestRunSubAgent_CancelledMapsToCancelled verifies that a context.Canceled +// from the agent run is mapped to the "cancelled" response (distinct from the +// generic failure), confirming the StatusCancelled branch is taken. +func TestRunSubAgent_CancelledMapsToCancelled(t *testing.T) { + t.Parallel() + + const providerID = "test-provider" + providerCfg := config.ProviderConfig{ID: providerID} + + env := testEnv(t) + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + cfg.Config().Providers.Set(providerID, providerCfg) + + rt := subagents.NewRuntime() + t.Cleanup(rt.Shutdown) + + parentSession, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) { + return nil, context.Canceled + }) + + coord := &coordinator{cfg: cfg, sessions: env.sessions, runtime: rt} + + resp, err := coord.runSubAgent(t.Context(), subAgentParams{ + Agent: agent, + SessionID: parentSession.ID, + AgentMessageID: "msg-1", + ToolCallID: "call-1", + Prompt: "do something", + SessionTitle: "Cancel Test", + AgentName: "a", + AgentColor: "red", + }) + require.NoError(t, err) + require.True(t, resp.IsError) + require.Equal(t, "Subagent cancelled by user", resp.Content) + + // The runtime entry must be gone after a cancelled run. + require.Empty(t, rt.List(parentSession.ID)) +} + +// TestActiveSubagentsList verifies the coordinator reads the live manager +// snapshot when present (so Library reloads are reflected) and falls back to +// the construction-time slice when no manager is wired. +func TestActiveSubagentsList(t *testing.T) { + t.Parallel() + + t.Run("live from manager reflects reload", func(t *testing.T) { + t.Parallel() + initial := []*subagents.Subagent{{Name: "x"}, {Name: "y"}} + mgr := subagents.NewManager(initial, initial, nil) + t.Cleanup(mgr.Shutdown) + c := &coordinator{subagentsMgr: mgr} + + require.Len(t, c.activeSubagentsList(), 2) + + reduced := []*subagents.Subagent{{Name: "x"}} + mgr.Reload(reduced, reduced, nil) + + got := c.activeSubagentsList() + require.Len(t, got, 1) + require.Equal(t, "x", got[0].Name) + }) + + t.Run("fallback when nil manager", func(t *testing.T) { + t.Parallel() + c := &coordinator{activeSubagents: []*subagents.Subagent{{Name: "z"}}} + got := c.activeSubagentsList() + require.Len(t, got, 1) + require.Equal(t, "z", got[0].Name) + }) +} + func TestGetProviderOptionsReasoningEffort(t *testing.T) { // Bedrock is Fantasy's Anthropic under a different provider name; options // must land under anthropic.Name so the Anthropic language model picks them up. @@ -426,3 +782,81 @@ func TestGetProviderOptionsReasoningEffort(t *testing.T) { }) } } + +// TestUpdateModels_ClearsSubagentModelCache verifies that UpdateModels empties +// the subagent model cache so stale LanguageModel instances are not reused +// after a config reload, even when UpdateModels itself returns an error. +func TestUpdateModels_ClearsSubagentModelCache(t *testing.T) { + t.Parallel() + + env := testEnv(t) + coord := &coordinator{ + cfg: config.NewTestStoreWithWorkingDir(&config.Config{}, env.workingDir), + sessions: env.sessions, + subagentModelCache: make(map[subagentModelKey]Model), + } + + // Manually populate the cache with a dummy entry. + coord.subagentModelCache[subagentModelKey{modelID: "some-model", provider: "", isSubAgent: true}] = Model{} + + require.Len(t, coord.subagentModelCache, 1) + + // UpdateModels will error (no models configured in empty config), but the + // cache must be cleared regardless. + _ = coord.UpdateModels(t.Context()) + + require.Empty(t, coord.subagentModelCache) +} + +// TestResolveModelByID_CacheHitSkipsBuild verifies that a second call to +// resolveModelByID with the same arguments returns the cached Model without +// repeating the provider build, and that errors are not cached. +func TestResolveModelByID_CacheHitSkipsBuild(t *testing.T) { + t.Parallel() + + env := testEnv(t) + // No-op provider with a known model so findModelProvider succeeds. + providerCfg := config.ProviderConfig{ + ID: "test-provider", + Models: []catwalk.Model{{ID: "model-x", DefaultMaxTokens: 4096}}, + } + coord := newTestCoordinator(t, env, "test-provider", providerCfg) + + // First call — cache is empty. + require.Empty(t, coord.subagentModelCache) + + _, err := coord.resolveModelByID(t.Context(), "model-x", "test-provider", true) + if err != nil { + // Provider construction may fail in the test environment (fake API key). + // Errors must not be cached. + require.Empty(t, coord.subagentModelCache, "failed build must not populate cache") + return + } + + // Success path: cache must contain exactly one entry. + require.Len(t, coord.subagentModelCache, 1) + + // Second call must hit the cache (same result, no error). + _, err2 := coord.resolveModelByID(t.Context(), "model-x", "test-provider", true) + require.NoError(t, err2) + require.Len(t, coord.subagentModelCache, 1, "second call must not add a new entry") +} + +// TestResolveModelByID_ModelNotFound verifies that resolveModelByID returns an +// error containing "not found" when no configured provider offers the requested +// model id, and that the cache is not populated on failure. +func TestResolveModelByID_ModelNotFound(t *testing.T) { + t.Parallel() + + env := testEnv(t) + providerCfg := config.ProviderConfig{ + ID: "test-provider", + Models: []catwalk.Model{{ID: "model-x", DefaultMaxTokens: 4096}}, + } + coord := newTestCoordinator(t, env, "test-provider", providerCfg) + + _, err := coord.resolveModelByID(t.Context(), "does-not-exist", "", true) + require.Error(t, err) + require.ErrorContains(t, err, "not found") + require.Empty(t, coord.subagentModelCache) +} diff --git a/internal/agent/dispatcher_gate_test.go b/internal/agent/dispatcher_gate_test.go new file mode 100644 index 0000000000..558097fb0a --- /dev/null +++ b/internal/agent/dispatcher_gate_test.go @@ -0,0 +1,50 @@ +package agent + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestShouldExposeDispatcher(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + allowed []string + isSubAgent bool + want bool + }{ + { + name: "top_level_with_agent_in_allowed", + allowed: []string{"bash", AgentToolName}, + isSubAgent: false, + want: true, + }, + { + name: "top_level_without_agent_in_allowed", + allowed: []string{"bash", "grep"}, + isSubAgent: false, + want: false, + }, + { + name: "subagent_with_agent_in_allowed_still_excluded", + allowed: []string{"bash", AgentToolName}, + isSubAgent: true, + want: false, + }, + { + name: "subagent_without_agent_excluded", + allowed: []string{"bash"}, + isSubAgent: true, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, shouldExposeDispatcher(tt.allowed, tt.isSubAgent)) + }) + } +} diff --git a/internal/agent/effort_dispatch_test.go b/internal/agent/effort_dispatch_test.go new file mode 100644 index 0000000000..d1dac6639f --- /dev/null +++ b/internal/agent/effort_dispatch_test.go @@ -0,0 +1,87 @@ +package agent + +import ( + "testing" + + "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/subagents" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// These tests pin the effort-application contract that buildAgent relies on: +// it calls subagents.ApplyEffortToModel on the resolved primary model +// (primary.ModelCfg, primary.CatwalkCfg). The cases mirror the dispatch path +// taken when a named subagent has Effort set. + +func TestApplyEffortToModel_HighEffort_OpenAI(t *testing.T) { + t.Parallel() + + cfg := config.SelectedModel{Model: "o4-mini", Provider: "openai"} + cat := catwalk.Model{ + ID: "o4-mini", + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, + } + + got := subagents.ApplyEffortToModel("high", cfg, cat) + + require.Equal(t, "high", got.ReasoningEffort, + "dispatch path must propagate high effort to ReasoningEffort") +} + +func TestApplyEffortToModel_HighEffort_Anthropic(t *testing.T) { + t.Parallel() + + cfg := config.SelectedModel{Model: "claude-opus-4-7", Provider: "anthropic"} + cat := catwalk.Model{ + ID: "claude-opus-4-7", + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high", "xhigh", "max"}, + } + + got := subagents.ApplyEffortToModel("high", cfg, cat) + + require.Equal(t, "high", got.ReasoningEffort, + "must set ReasoningEffort for an Anthropic model with high effort") + assert.False(t, got.Think, "Think must never be set by ApplyEffortToModel") +} + +func TestApplyEffortToModel_EmptyEffort_NoOp(t *testing.T) { + t.Parallel() + + cfg := config.SelectedModel{Model: "o4-mini", Provider: "openai"} + cat := catwalk.Model{ID: "o4-mini", CanReason: true, ReasoningLevels: []string{"low", "high"}} + + got := subagents.ApplyEffortToModel("", cfg, cat) + + assert.Empty(t, got.ReasoningEffort, "empty effort must not set ReasoningEffort") + assert.False(t, got.Think, "empty effort must not set Think") +} + +func TestApplyEffortToModel_PreservesOtherFields(t *testing.T) { + t.Parallel() + + cfg := config.SelectedModel{Model: "o4-mini", Provider: "openai", MaxTokens: 4096} + cat := catwalk.Model{ID: "o4-mini", CanReason: true, ReasoningLevels: []string{"low", "medium", "high"}} + + got := subagents.ApplyEffortToModel("high", cfg, cat) + + assert.Equal(t, "o4-mini", got.Model) + assert.Equal(t, "openai", got.Provider) + assert.Equal(t, int64(4096), got.MaxTokens) + assert.Equal(t, "high", got.ReasoningEffort) +} + +func TestApplyEffortToModel_NonReasoningModel(t *testing.T) { + t.Parallel() + + cfg := config.SelectedModel{Model: "gpt-4o", Provider: "openai"} + cat := catwalk.Model{ID: "gpt-4o", CanReason: false} + + got := subagents.ApplyEffortToModel("high", cfg, cat) + + assert.Empty(t, got.ReasoningEffort, "non-reasoning model must not have ReasoningEffort set") + assert.False(t, got.Think, "non-reasoning model must not have Think set") +} diff --git a/internal/agent/prompt/prompt.go b/internal/agent/prompt/prompt.go index 7609661f31..b17329ff89 100644 --- a/internal/agent/prompt/prompt.go +++ b/internal/agent/prompt/prompt.go @@ -9,7 +9,7 @@ import ( "path/filepath" "runtime" "strings" - "text/template" + "text/template" // nosemgrep: go.lang.security.audit.xss.import-text-template.import-text-template "time" "github.com/charmbracelet/crush/internal/config" @@ -21,24 +21,32 @@ import ( // Prompt represents a template-based prompt generator. type Prompt struct { - name string - template string - now func() time.Time - platform string - workingDir string + name string + template string + now func() time.Time + platform string + workingDir string + subagentBody string + preloadedSkillsXML string + // suppressAvailableSkills omits the discovery list. Set + // for subagents that pin an explicit skills set, so the preloaded skills are + // their only skill exposure. + suppressAvailableSkills bool } type PromptDat struct { - Provider string - Model string - Config config.Config - WorkingDir string - IsGitRepo bool - Platform string - Date string - GitStatus string - ContextFiles []ContextFile - AvailSkillXML string + Provider string + Model string + Config config.Config + WorkingDir string + IsGitRepo bool + Platform string + Date string + GitStatus string + ContextFiles []ContextFile + AvailSkillXML string + SubagentBody string + PreloadedSkillsXML string } type ContextFile struct { @@ -66,6 +74,18 @@ func WithWorkingDir(workingDir string) Option { } } +func WithSuppressAvailableSkills(suppress bool) Option { + return func(p *Prompt) { p.suppressAvailableSkills = suppress } +} + +func WithSubagentBody(body string) Option { + return func(p *Prompt) { p.subagentBody = body } +} + +func WithPreloadedSkillsXML(xml string) Option { + return func(p *Prompt) { p.preloadedSkillsXML = xml } +} + func NewPrompt(name, promptTemplate string, opts ...Option) (*Prompt, error) { p := &Prompt{ name: name, @@ -148,6 +168,15 @@ func expandPath(path string, store *config.ConfigStore) string { } func (p *Prompt) promptData(ctx context.Context, provider, model string, store *config.ConfigStore) (PromptDat, error) { + if store == nil { + return PromptDat{ + Provider: provider, + Model: model, + SubagentBody: p.subagentBody, + PreloadedSkillsXML: p.preloadedSkillsXML, + }, nil + } + workingDir := cmp.Or(p.workingDir, store.WorkingDir()) platform := cmp.Or(p.platform, runtime.GOOS) @@ -164,50 +193,55 @@ func (p *Prompt) promptData(ctx context.Context, provider, model string, store * files[pathKey] = content } - // Discover and load skills metadata. + // Discover and load skills metadata. Skipped entirely when the prompt + // suppresses the available-skills list (subagents that pin an explicit + // skills set), which also avoids the discovery filesystem walk. var availSkillXML string - - // Start with builtin skills. - allSkills := skills.DiscoverBuiltin() - builtinNames := make(map[string]bool, len(allSkills)) - for _, s := range allSkills { - builtinNames[s.Name] = true - } - - // Discover user skills from configured paths. - if len(cfg.Options.SkillsPaths) > 0 { - expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths)) - for _, pth := range cfg.Options.SkillsPaths { - expandedPaths = append(expandedPaths, expandPath(pth, store)) + if !p.suppressAvailableSkills { + // Start with builtin skills. + allSkills := skills.DiscoverBuiltin() + builtinNames := make(map[string]bool, len(allSkills)) + for _, s := range allSkills { + builtinNames[s.Name] = true } - for _, userSkill := range skills.Discover(expandedPaths) { - if builtinNames[userSkill.Name] { - slog.Warn("User skill overrides builtin skill", "name", userSkill.Name) + + // Discover user skills from configured paths. + if len(cfg.Options.SkillsPaths) > 0 { + expandedPaths := make([]string, 0, len(cfg.Options.SkillsPaths)) + for _, pth := range cfg.Options.SkillsPaths { + expandedPaths = append(expandedPaths, expandPath(pth, store)) + } + for _, userSkill := range skills.Discover(expandedPaths) { + if builtinNames[userSkill.Name] { + slog.Warn("User skill overrides builtin skill", "name", userSkill.Name) + } + allSkills = append(allSkills, userSkill) } - allSkills = append(allSkills, userSkill) } - } - // Deduplicate: user skills override builtins with the same name. - allSkills = skills.Deduplicate(allSkills) + // Deduplicate: user skills override builtins with the same name. + allSkills = skills.Deduplicate(allSkills) - // Filter out disabled skills. - allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills) + // Filter out disabled skills. + allSkills = skills.Filter(allSkills, cfg.Options.DisabledSkills) - if len(allSkills) > 0 { - availSkillXML = skills.ToPromptXML(allSkills) + if len(allSkills) > 0 { + availSkillXML = skills.ToPromptXML(allSkills) + } } isGit := isGitRepo(store.WorkingDir()) data := PromptDat{ - Provider: provider, - Model: model, - Config: *cfg, - WorkingDir: filepath.ToSlash(workingDir), - IsGitRepo: isGit, - Platform: platform, - Date: p.now().Format("1/2/2006"), - AvailSkillXML: availSkillXML, + Provider: provider, + Model: model, + Config: *cfg, + WorkingDir: filepath.ToSlash(workingDir), + IsGitRepo: isGit, + Platform: platform, + Date: p.now().Format("1/2/2006"), + AvailSkillXML: availSkillXML, + SubagentBody: p.subagentBody, + PreloadedSkillsXML: p.preloadedSkillsXML, } if isGit { var err error diff --git a/internal/agent/prompt/subagent_prompt_options_test.go b/internal/agent/prompt/subagent_prompt_options_test.go new file mode 100644 index 0000000000..8fdfb3e392 --- /dev/null +++ b/internal/agent/prompt/subagent_prompt_options_test.go @@ -0,0 +1,127 @@ +package prompt + +import ( + "context" + "testing" + + "github.com/charmbracelet/crush/internal/config" + "github.com/stretchr/testify/require" +) + +// TestWithSuppressAvailableSkills verifies the option omits +// from the rendered prompt even though builtin skills exist. Needs a real store +// because the nil-store path never computes AvailSkillXML. +func TestWithSuppressAvailableSkills(t *testing.T) { + t.Parallel() + + store, err := config.Init(t.TempDir(), "", false) + require.NoError(t, err) + + const tmpl = `{{.AvailSkillXML}}` + + open, err := NewPrompt("t", tmpl) + require.NoError(t, err) + got, err := open.Build(context.Background(), "p", "m", store) + require.NoError(t, err) + require.Contains(t, got, "", "available skills render by default") + + suppressed, err := NewPrompt("t", tmpl, WithSuppressAvailableSkills(true)) + require.NoError(t, err) + got, err = suppressed.Build(context.Background(), "p", "m", store) + require.NoError(t, err) + require.NotContains(t, got, "", "available skills suppressed by option") +} + +// TestWithSubagentBody verifies that WithSubagentBody stores the body string in +// PromptDat.SubagentBody and that the template can render it. +func TestWithSubagentBody(t *testing.T) { + t.Parallel() + + const body = "You are a specialist agent that does things." + + // Use a template that renders SubagentBody so we can observe the value + // without needing access to the unexported promptData method. + p, err := NewPrompt("test", `{{.SubagentBody}}`, WithSubagentBody(body)) + require.NoError(t, err) + + // A nil store makes promptData return a minimal PromptDat (it otherwise + // needs store.WorkingDir()), which still carries the subagent option fields. + result, err := p.Build(context.Background(), "test-provider", "test-model", nil) + require.NoError(t, err) + require.Equal(t, body, result) +} + +// TestWithPreloadedSkillsXML verifies that WithPreloadedSkillsXML stores the +// XML string in PromptDat.PreloadedSkillsXML and that the template can render it. +func TestWithPreloadedSkillsXML(t *testing.T) { + t.Parallel() + + const xml = "\n my-skill\n" + + p, err := NewPrompt("test", `{{.PreloadedSkillsXML}}`, WithPreloadedSkillsXML(xml)) + require.NoError(t, err) + + result, err := p.Build(context.Background(), "test-provider", "test-model", nil) + require.NoError(t, err) + require.Equal(t, xml, result) +} + +// TestSubagentPromptOptions_BothFieldsInTemplate verifies that both +// SubagentBody and PreloadedSkillsXML are accessible from the template when +// both options are provided. +func TestSubagentPromptOptions_BothFieldsInTemplate(t *testing.T) { + t.Parallel() + + const ( + body = "Do the specialist thing." + xml = "test-skill" + ) + + tmpl := `{{.SubagentBody}}|{{.PreloadedSkillsXML}}` + p, err := NewPrompt("test", tmpl, WithSubagentBody(body), WithPreloadedSkillsXML(xml)) + require.NoError(t, err) + + result, err := p.Build(context.Background(), "test-provider", "test-model", nil) + require.NoError(t, err) + require.Equal(t, body+"|"+xml, result) +} + +// TestSubagentPromptOptions_DefaultsToEmpty verifies that SubagentBody and +// PreloadedSkillsXML are empty strings when neither option is provided. +func TestSubagentPromptOptions_DefaultsToEmpty(t *testing.T) { + t.Parallel() + + tmpl := `body=«{{.SubagentBody}}»xml=«{{.PreloadedSkillsXML}}»` + p, err := NewPrompt("test", tmpl) + require.NoError(t, err) + + result, err := p.Build(context.Background(), "test-provider", "test-model", nil) + require.NoError(t, err) + require.Equal(t, "body=«»xml=«»", result) +} + +// TestWithSubagentBody_EmptyString verifies that an empty body string is stored +// and rendered correctly (no panic, no unexpected fallback). +func TestWithSubagentBody_EmptyString(t *testing.T) { + t.Parallel() + + p, err := NewPrompt("test", `{{.SubagentBody}}`, WithSubagentBody("")) + require.NoError(t, err) + + result, err := p.Build(context.Background(), "test-provider", "test-model", nil) + require.NoError(t, err) + require.Equal(t, "", result) +} + +// TestWithPreloadedSkillsXML_EmptyString verifies that an empty XML string is +// stored and rendered correctly. +func TestWithPreloadedSkillsXML_EmptyString(t *testing.T) { + t.Parallel() + + p, err := NewPrompt("test", `{{.PreloadedSkillsXML}}`, WithPreloadedSkillsXML("")) + require.NoError(t, err) + + result, err := p.Build(context.Background(), "test-provider", "test-model", nil) + require.NoError(t, err) + require.Equal(t, "", result) +} diff --git a/internal/agent/prompts.go b/internal/agent/prompts.go index 448fe0425c..ec863d2066 100644 --- a/internal/agent/prompts.go +++ b/internal/agent/prompts.go @@ -3,9 +3,13 @@ package agent import ( "context" _ "embed" + "log/slog" + "strings" "github.com/charmbracelet/crush/internal/agent/prompt" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/skills" + "github.com/charmbracelet/crush/internal/subagents" ) //go:embed templates/coder.md.tpl @@ -17,6 +21,9 @@ var taskPromptTmpl []byte //go:embed templates/initialize.md.tpl var initializePromptTmpl []byte +//go:embed templates/subagent.md.tpl +var subagentPromptTmpl []byte + func coderPrompt(opts ...prompt.Option) (*prompt.Prompt, error) { systemPrompt, err := prompt.NewPrompt("coder", string(coderPromptTmpl), opts...) if err != nil { @@ -33,6 +40,45 @@ func taskPrompt(opts ...prompt.Option) (*prompt.Prompt, error) { return systemPrompt, nil } +func resolvePreloadedSkillsXML(skillNames []string, activeSkills []*skills.Skill) string { + if len(skillNames) == 0 { + return "" + } + byName := make(map[string]*skills.Skill, len(activeSkills)) + for _, s := range activeSkills { + byName[s.Name] = s + } + var parts []string + for _, name := range skillNames { + s, ok := byName[name] + if !ok { + slog.Warn("Subagent references unknown skill", "skill", name) + continue + } + if s.DisableModelInvocation { + slog.Warn("Subagent references skill with disable-model-invocation, skipping", "skill", name) + continue + } + parts = append(parts, s.FormatInvocation()) + } + return strings.Join(parts, "\n") +} + +func subagentPrompt(sa *subagents.Subagent, activeSkills []*skills.Skill, opts ...prompt.Option) (*prompt.Prompt, error) { + preloadedXML := resolvePreloadedSkillsXML(sa.Skills, activeSkills) + allOpts := make([]prompt.Option, 0, len(opts)+3) + allOpts = append( + allOpts, + prompt.WithSubagentBody(sa.Body), + prompt.WithPreloadedSkillsXML(preloadedXML), + // A pinned skills set is the subagent's only skill exposure: suppress + // the broad discovery list so it can't reach others. + prompt.WithSuppressAvailableSkills(len(sa.Skills) > 0), + ) + allOpts = append(allOpts, opts...) + return prompt.NewPrompt("subagent", string(subagentPromptTmpl), allOpts...) +} + func InitializePrompt(cfg *config.ConfigStore) (string, error) { systemPrompt, err := prompt.NewPrompt("initialize", string(initializePromptTmpl)) if err != nil { diff --git a/internal/agent/subagent_prompt_test.go b/internal/agent/subagent_prompt_test.go new file mode 100644 index 0000000000..840a51d4ca --- /dev/null +++ b/internal/agent/subagent_prompt_test.go @@ -0,0 +1,322 @@ +package agent + +import ( + "context" + "strings" + "testing" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/skills" + "github.com/charmbracelet/crush/internal/subagents" + "github.com/stretchr/testify/require" +) + +// newTestSkill constructs a minimal *skills.Skill for use in unit tests. +// The name must pass skills.Validate, so it must be alphanumeric with hyphens. +func newTestSkill(name string, disableModelInvocation bool) *skills.Skill { + return &skills.Skill{ + Name: name, + Description: "test skill: " + name, + DisableModelInvocation: disableModelInvocation, + Instructions: "do the " + name + " thing", + } +} + +// newTestSubagent constructs a minimal *subagents.Subagent whose fields satisfy +// any validation that subagentPrompt may perform internally. +func newTestSubagent(name string, skillNames []string, body string) *subagents.Subagent { + return &subagents.Subagent{ + Name: name, + Description: "test subagent " + name, + Skills: skillNames, + Body: body, + } +} + +// --------------------------------------------------------------------------- +// resolvePreloadedSkillsXML +// --------------------------------------------------------------------------- + +func TestResolvePreloadedSkillsXML(t *testing.T) { + t.Parallel() + + alpha := newTestSkill("alpha", false) + beta := newTestSkill("beta", false) + gamma := newTestSkill("gamma", true) // DisableModelInvocation = true + + tests := []struct { + name string + skillNames []string + activeSkills []*skills.Skill + // wantContains is a slice of strings that must appear in the result. + wantContains []string + // wantAbsent is a slice of strings that must NOT appear in the result. + wantAbsent []string + // wantEmpty asserts the result is the empty string. + wantEmpty bool + }{ + { + name: "empty_skill_names", + skillNames: nil, + activeSkills: []*skills.Skill{ + alpha, + }, + wantEmpty: true, + }, + { + name: "empty_active_skills", + skillNames: []string{"alpha"}, + activeSkills: nil, + wantEmpty: true, + }, + { + name: "single_skill_found", + skillNames: []string{"alpha"}, + activeSkills: []*skills.Skill{alpha}, + wantContains: []string{"alpha"}, + }, + { + name: "skill_not_found", + skillNames: []string{"missing"}, + activeSkills: []*skills.Skill{alpha, beta}, + wantEmpty: true, + }, + { + name: "disable_model_invocation_skipped", + skillNames: []string{"gamma"}, + activeSkills: []*skills.Skill{gamma}, + wantEmpty: true, + }, + { + name: "multiple_skills_some_found", + skillNames: []string{"alpha", "missing", "beta"}, + activeSkills: []*skills.Skill{alpha, beta}, + wantContains: []string{"alpha", "beta"}, + wantAbsent: []string{"missing"}, + }, + { + name: "preserves_order", + skillNames: []string{"beta", "alpha"}, + activeSkills: []*skills.Skill{alpha, beta}, + // beta's FormatInvocation output must appear before alpha's in the result + wantContains: []string{"beta", "alpha"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := resolvePreloadedSkillsXML(tc.skillNames, tc.activeSkills) + + if tc.wantEmpty { + require.Empty(t, got) + return + } + + for _, want := range tc.wantContains { + require.Contains(t, got, want) + } + + for _, absent := range tc.wantAbsent { + require.NotContains(t, got, absent) + } + }) + } +} + +// TestResolvePreloadedSkillsXML_PreservesOrder verifies that when multiple +// skills are requested the output XML segments appear in skillNames order, not +// in activeSkills order. +func TestResolvePreloadedSkillsXML_PreservesOrder(t *testing.T) { + t.Parallel() + + alpha := newTestSkill("alpha", false) + beta := newTestSkill("beta", false) + + // Request beta before alpha even though activeSkills has alpha first. + got := resolvePreloadedSkillsXML([]string{"beta", "alpha"}, []*skills.Skill{alpha, beta}) + + betaIdx := strings.Index(got, "beta") + alphaIdx := strings.Index(got, "alpha") + + require.NotEqual(t, -1, betaIdx, "beta should appear in result") + require.NotEqual(t, -1, alphaIdx, "alpha should appear in result") + require.Less(t, betaIdx, alphaIdx, "beta should appear before alpha in output") +} + +// TestResolvePreloadedSkillsXML_FormatInvocationUsed verifies that the output +// for a found skill is derived from FormatInvocation() and therefore contains +// the wrapper element. +func TestResolvePreloadedSkillsXML_FormatInvocationUsed(t *testing.T) { + t.Parallel() + + sk := newTestSkill("my-skill", false) + got := resolvePreloadedSkillsXML([]string{"my-skill"}, []*skills.Skill{sk}) + + // FormatInvocation always wraps output in . + require.Contains(t, got, "") + require.Contains(t, got, "my-skill") +} + +// --------------------------------------------------------------------------- +// subagentPrompt +// --------------------------------------------------------------------------- + +func TestSubagentPrompt_NoSkills(t *testing.T) { + t.Parallel() + + sa := newTestSubagent("no-skills-agent", nil, "You do things.") + p, err := subagentPrompt(sa, nil) + + require.NoError(t, err) + require.NotNil(t, p) +} + +func TestSubagentPrompt_WithKnownSkill(t *testing.T) { + t.Parallel() + + sk := newTestSkill("helper-skill", false) + sa := newTestSubagent("skilled-agent", []string{"helper-skill"}, "You use the helper skill.") + + p, err := subagentPrompt(sa, []*skills.Skill{sk}) + + require.NoError(t, err) + require.NotNil(t, p) +} + +func TestSubagentPrompt_WithUnknownSkill(t *testing.T) { + t.Parallel() + + // Subagent requests a skill that is not in activeSkills — must not error. + sa := newTestSubagent("unknown-skill-agent", []string{"nonexistent"}, "Body text.") + + p, err := subagentPrompt(sa, nil) + + require.NoError(t, err) + require.NotNil(t, p) +} + +func TestSubagentPrompt_NilSubagentSkills(t *testing.T) { + t.Parallel() + + sa := newTestSubagent("nil-skills-agent", nil, "") + activeSkills := []*skills.Skill{newTestSkill("some-skill", false)} + + p, err := subagentPrompt(sa, activeSkills) + + require.NoError(t, err) + require.NotNil(t, p) +} + +// TestSubagentPrompt_Build_RendersBody confirms that the subagent template +// actually emits SubagentBody when the prompt is built. A typo in +// subagent.md.tpl would otherwise pass the existing not-nil tests. +func TestSubagentPrompt_Build_RendersBody(t *testing.T) { + t.Parallel() + + body := "You are a specialist that handles XYZ tasks." + sa := newTestSubagent("build-render", nil, body) + + p, err := subagentPrompt(sa, nil) + require.NoError(t, err) + + got, err := p.Build(context.Background(), "p", "m", nil) + require.NoError(t, err) + require.Contains(t, got, body) +} + +// TestSubagentPrompt_Build_RendersPreloadedSkillsXML confirms that resolved +// preloaded skill invocations actually flow through to the rendered prompt. +func TestSubagentPrompt_Build_RendersPreloadedSkillsXML(t *testing.T) { + t.Parallel() + + sk := newTestSkill("preload-me", false) + sa := newTestSubagent("with-preload", []string{"preload-me"}, "Body.") + + p, err := subagentPrompt(sa, []*skills.Skill{sk}) + require.NoError(t, err) + + got, err := p.Build(context.Background(), "p", "m", nil) + require.NoError(t, err) + require.Contains(t, got, "") + require.Contains(t, got, "preload-me") +} + +// TestSubagentPrompt_Build_OmitsPreloadWhenEmpty verifies that the template's +// guard around PreloadedSkillsXML keeps the output clean when no skills are +// requested. Catches accidental literal `` leak. +func TestSubagentPrompt_Build_OmitsPreloadWhenEmpty(t *testing.T) { + t.Parallel() + + sa := newTestSubagent("no-preload", nil, "Body.") + + p, err := subagentPrompt(sa, nil) + require.NoError(t, err) + + got, err := p.Build(context.Background(), "p", "m", nil) + require.NoError(t, err) + require.NotContains(t, got, "") +} + +// TestSubagentPrompt_Build_SuppressesAvailableWhenSkillsPinned verifies that a +// subagent with a pinned skills set gets its skills preloaded and the broad +// discovery list suppressed. Uses a real store so promptData +// runs the full path (the nil-store path never computes AvailSkillXML). +func TestSubagentPrompt_Build_SuppressesAvailableWhenSkillsPinned(t *testing.T) { + t.Parallel() + + store, err := config.Init(t.TempDir(), "", false) + require.NoError(t, err) + + sk := newTestSkill("preload-me", false) + sa := newTestSubagent("scoped", []string{"preload-me"}, "Body.") + + p, err := subagentPrompt(sa, []*skills.Skill{sk}) + require.NoError(t, err) + + got, err := p.Build(context.Background(), "p", "m", store) + require.NoError(t, err) + require.Contains(t, got, "", "pinned skill must be preloaded") + require.NotContains(t, got, "", "available list must be suppressed when skills are pinned") +} + +// TestSubagentPrompt_Build_RendersAvailableWhenNoSkillsPinned verifies the +// default (no skills:): the discovery list renders (builtins present) and +// nothing is preloaded. +func TestSubagentPrompt_Build_RendersAvailableWhenNoSkillsPinned(t *testing.T) { + t.Parallel() + + store, err := config.Init(t.TempDir(), "", false) + require.NoError(t, err) + + sa := newTestSubagent("open", nil, "Body.") + + p, err := subagentPrompt(sa, nil) + require.NoError(t, err) + + got, err := p.Build(context.Background(), "p", "m", store) + require.NoError(t, err) + require.Contains(t, got, "", "available list must render when no skills are pinned") + require.NotContains(t, got, "") +} + +// TestSubagentPrompt_Build_SuppressesEvenWhenSkillsUnresolved documents the +// accepted lenient edge: a pinned-but-unknown skill name suppresses available +// yet preloads nothing, so the subagent gets no skills section at all. +func TestSubagentPrompt_Build_SuppressesEvenWhenSkillsUnresolved(t *testing.T) { + t.Parallel() + + store, err := config.Init(t.TempDir(), "", false) + require.NoError(t, err) + + sa := newTestSubagent("scoped-typo", []string{"does-not-exist"}, "Body.") + + p, err := subagentPrompt(sa, nil) + require.NoError(t, err) + + got, err := p.Build(context.Background(), "p", "m", store) + require.NoError(t, err) + require.NotContains(t, got, "") + require.NotContains(t, got, "") +} diff --git a/internal/agent/templates/subagent.md.tpl b/internal/agent/templates/subagent.md.tpl new file mode 100644 index 0000000000..f768103c1e --- /dev/null +++ b/internal/agent/templates/subagent.md.tpl @@ -0,0 +1,18 @@ +{{- if .SubagentBody}} +{{.SubagentBody}} +{{- end}} +{{- if .PreloadedSkillsXML}} + +{{.PreloadedSkillsXML}} +{{- end}} +{{- if .AvailSkillXML}} + +{{.AvailSkillXML}} +{{- end}} + + +Working directory: {{.WorkingDir}} +Is directory a git repo: {{if .IsGitRepo}} yes {{else}} no {{end}} +Platform: {{.Platform}} +Today's date: {{.Date}} + diff --git a/internal/app/app.go b/internal/app/app.go index d8a3abc63b..a77c70a590 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -35,6 +35,7 @@ import ( "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/shell" "github.com/charmbracelet/crush/internal/skills" + "github.com/charmbracelet/crush/internal/subagents" "github.com/charmbracelet/crush/internal/ui/anim" "github.com/charmbracelet/crush/internal/ui/styles" "github.com/charmbracelet/crush/internal/update" @@ -62,7 +63,9 @@ type App struct { LSPManager *lsp.Manager - Skills *skills.Manager + Skills *skills.Manager + Subagents *subagents.Manager + SubagentRuntime *subagents.Runtime config *config.ConfigStore @@ -87,8 +90,9 @@ type App struct { // New initializes a new application instance. skillsMgr carries the // per-workspace skill discovery results computed by the caller; the // caller is responsible for constructing it (typically via -// skills.NewManager + skills.DiscoverFromConfig). -func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore, skillsMgr *skills.Manager) (*App, error) { +// skills.NewManager + skills.DiscoverFromConfig). subagentsMgr carries +// the per-workspace subagent discovery results; may be nil. +func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore, skillsMgr *skills.Manager, subagentsMgr *subagents.Manager) (*App, error) { q := db.New(conn) sessions := session.NewService(q, conn) messages := message.NewService(q) @@ -108,6 +112,7 @@ func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore, skillsMgr FileTracker: filetracker.NewService(q), LSPManager: lsp.NewManager(store), Skills: skillsMgr, + Subagents: subagentsMgr, globalCtx: ctx, @@ -578,6 +583,9 @@ func (app *App) InitCoderAgent(ctx context.Context) error { if coderAgentCfg.ID == "" { return fmt.Errorf("coder agent configuration is missing") } + if app.SubagentRuntime == nil { + app.SubagentRuntime = subagents.NewRuntime() + } var err error app.AgentCoordinator, err = agent.NewCoordinator( ctx, @@ -591,6 +599,8 @@ func (app *App) InitCoderAgent(ctx context.Context) error { app.agentNotifications, app.runCompletions, app.Skills, + app.Subagents, + app.SubagentRuntime, ) if err != nil { slog.Error("Failed to create coder agent", "err", err) @@ -616,6 +626,24 @@ func (app *App) Subscribe(program *tea.Program) { }) defer app.tuiWG.Done() + if app.SubagentRuntime != nil { + rtEvents := app.SubagentRuntime.Subscribe(tuiCtx) + go func() { + for ev := range rtEvents { + program.Send(ev) + } + }() + } + + if app.Subagents != nil { + discEvents := app.Subagents.SubscribeEvents(tuiCtx) + go func() { + for ev := range discEvents { + program.Send(ev) + } + }() + } + events := app.events.Subscribe(tuiCtx) for { select { diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 2ea24a86c7..59d9c3eee0 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -19,6 +19,7 @@ import ( "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/skills" + "github.com/charmbracelet/crush/internal/subagents" "github.com/charmbracelet/crush/internal/ui/util" "github.com/charmbracelet/crush/internal/version" "github.com/google/uuid" @@ -294,7 +295,15 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works skills.WithWorkingDir(discoveryCfg.WorkingDir), ) - appWorkspace, err := app.New(b.ctx, conn, cfg, skillsMgr) + subagentsCfg := subagents.DiscoveryConfig{ + SubagentsPaths: cfg.Config().Options.SubagentsPaths, + DisabledSubagents: cfg.Config().Options.DisabledSubagents, + IsKnownModel: cfg.Config().IsKnownModel, + } + allSubagents, activeSubagents, subagentStates := subagents.DiscoverFromConfig(subagentsCfg) + subagentsMgr := subagents.NewManager(allSubagents, activeSubagents, subagentStates) + + appWorkspace, err := app.New(b.ctx, conn, cfg, skillsMgr, subagentsMgr) if err != nil { return nil, proto.Workspace{}, fmt.Errorf("failed to create app workspace: %w", err) } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index a8364ac4af..c6fc69e871 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -36,6 +36,7 @@ import ( "github.com/charmbracelet/crush/internal/server" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/skills" + "github.com/charmbracelet/crush/internal/subagents" "github.com/charmbracelet/crush/internal/ui/common" ui "github.com/charmbracelet/crush/internal/ui/model" "github.com/charmbracelet/crush/internal/version" @@ -301,7 +302,15 @@ func setupLocalWorkspace(cmd *cobra.Command) (workspace.Workspace, func(), error skills.WithWorkingDir(discoveryCfg.WorkingDir), ) - appInstance, err := app.New(ctx, conn, store, skillsMgr) + subagentsCfg := subagents.DiscoveryConfig{ + SubagentsPaths: cfg.Options.SubagentsPaths, + DisabledSubagents: cfg.Options.DisabledSubagents, + IsKnownModel: cfg.IsKnownModel, + } + allSubagents, activeSubagents, subagentStates := subagents.DiscoverFromConfig(subagentsCfg) + subagentsMgr := subagents.NewManager(allSubagents, activeSubagents, subagentStates) + + appInstance, err := app.New(ctx, conn, store, skillsMgr, subagentsMgr) if err != nil { _ = conn.Close() slog.Error("Failed to create app instance", "error", err) diff --git a/internal/config/config.go b/internal/config/config.go index fc3bab3302..86360a8d5a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -281,6 +281,8 @@ type Options struct { DisableNotifications bool `json:"disable_notifications,omitempty" jsonschema:"description=Deprecated: Use notification_style instead. Disable desktop notifications,default=false"` NotificationStyle string `json:"notification_style,omitempty" jsonschema:"description=Notification style to use. Options: auto (default), native, osc, bell, disabled. Auto selects based on environment: native for local sessions, osc for SSH (with automatic OSC 99/777 detection).,enum=auto,enum=native,enum=osc,enum=bell,enum=disabled,default=auto"` DisabledSkills []string `json:"disabled_skills,omitempty" jsonschema:"description=List of skill names to disable and hide from the agent,example=crush-config"` + SubagentsPaths []string `json:"subagents_paths,omitempty" jsonschema:"description=Paths to directories containing subagent definition files (*.md files with YAML frontmatter)"` + DisabledSubagents []string `json:"disabled_subagents,omitempty" jsonschema:"description=List of subagent names to disable and hide from the agent"` } type MCPs map[string]MCPConfig @@ -616,6 +618,34 @@ func (c *Config) GetModel(provider, model string) *catwalk.Model { return nil } +// IsKnownModelID reports whether modelID matches the ID of any model offered +// by any provider in the config. Walks every provider since model IDs are +// unique per provider but callers identifying a model by ID alone do not have +// provider context. +func (c *Config) IsKnownModelID(modelID string) bool { + if modelID == "" { + return false + } + for p := range c.Providers.Seq() { + for _, m := range p.Models { + if m.ID == modelID { + return true + } + } + } + return false +} + +// IsKnownModel reports whether the given model is offered by the given +// provider. When provider is empty, it scans all providers (equivalent to +// IsKnownModelID). +func (c *Config) IsKnownModel(provider, modelID string) bool { + if provider == "" { + return c.IsKnownModelID(modelID) + } + return c.GetModel(provider, modelID) != nil +} + func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfig { model, ok := c.Models[modelType] if !ok { diff --git a/internal/config/known_model_test.go b/internal/config/known_model_test.go new file mode 100644 index 0000000000..f3f9961a36 --- /dev/null +++ b/internal/config/known_model_test.go @@ -0,0 +1,147 @@ +package config + +import ( + "testing" + + "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/csync" + "github.com/stretchr/testify/require" +) + +func newConfigWithProviders(t *testing.T, providers map[string][]string) *Config { + t.Helper() + + pMap := csync.NewMap[string, ProviderConfig]() + for id, modelIDs := range providers { + models := make([]catwalk.Model, 0, len(modelIDs)) + for _, mid := range modelIDs { + models = append(models, catwalk.Model{ID: mid}) + } + pMap.Set(id, ProviderConfig{ID: id, Models: models}) + } + return &Config{Providers: pMap} +} + +func TestConfig_IsKnownModelID(t *testing.T) { + t.Parallel() + + cfg := newConfigWithProviders(t, map[string][]string{ + "openai": {"gpt-4o", "gpt-4o-mini"}, + "anthropic": {"claude-opus-4-7", "claude-sonnet-4-6"}, + }) + + tests := []struct { + name string + id string + want bool + }{ + {"empty_string", "", false}, + {"unknown_id", "imaginary-99", false}, + {"first_provider_first_model", "gpt-4o", true}, + {"first_provider_second_model", "gpt-4o-mini", true}, + {"second_provider", "claude-opus-4-7", true}, + {"case_sensitive", "GPT-4o", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, cfg.IsKnownModelID(tt.id)) + }) + } +} + +func TestConfig_IsKnownModelID_NoProviders(t *testing.T) { + t.Parallel() + + cfg := newConfigWithProviders(t, nil) + require.False(t, cfg.IsKnownModelID("gpt-4o")) + require.False(t, cfg.IsKnownModelID("")) +} + +func TestConfig_IsKnownModel(t *testing.T) { + t.Parallel() + + cfg := newConfigWithProviders(t, map[string][]string{ + "openai": {"gpt-4o", "gpt-4o-mini"}, + "anthropic": {"claude-opus-4-7", "claude-sonnet-4-6"}, + }) + + tests := []struct { + name string + provider string + modelID string + want bool + }{ + { + name: "empty_provider_scans_all_known_id", + provider: "", + modelID: "gpt-4o", + want: true, + }, + { + name: "empty_provider_scans_all_unknown_id", + provider: "", + modelID: "imaginary-99", + want: false, + }, + { + name: "empty_provider_empty_model", + provider: "", + modelID: "", + want: false, + }, + { + name: "specific_provider_model_match", + provider: "openai", + modelID: "gpt-4o", + want: true, + }, + { + name: "specific_provider_model_no_match", + provider: "openai", + modelID: "claude-opus-4-7", + want: false, + }, + { + name: "specific_provider_unknown_model", + provider: "openai", + modelID: "imaginary-99", + want: false, + }, + { + name: "unknown_provider", + provider: "nonexistent", + modelID: "gpt-4o", + want: false, + }, + { + name: "second_provider_specific", + provider: "anthropic", + modelID: "claude-opus-4-7", + want: true, + }, + { + name: "case_sensitive_provider", + provider: "openai", + modelID: "GPT-4o", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, cfg.IsKnownModel(tt.provider, tt.modelID)) + }) + } +} + +func TestConfig_IsKnownModel_NoProviders(t *testing.T) { + t.Parallel() + + cfg := newConfigWithProviders(t, nil) + require.False(t, cfg.IsKnownModel("", "gpt-4o")) + require.False(t, cfg.IsKnownModel("openai", "gpt-4o")) + require.False(t, cfg.IsKnownModel("", "")) +} diff --git a/internal/config/load.go b/internal/config/load.go index 2f0946e7bc..9334c2a74f 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -461,6 +461,15 @@ func (c *Config) setDefaults(workingDir, dataDir string) { // Project specific skills dirs. c.Options.SkillsPaths = append(c.Options.SkillsPaths, ProjectSkillsDir(workingDir)...) + // Add the default subagents directories if not already present. + for _, dir := range GlobalSubagentsDirs() { + if !slices.Contains(c.Options.SubagentsPaths, dir) { + c.Options.SubagentsPaths = append(c.Options.SubagentsPaths, dir) + } + } + // Project specific subagents dirs. + c.Options.SubagentsPaths = append(c.Options.SubagentsPaths, ProjectSubagentsDir(workingDir)...) + if str, ok := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE"); ok { c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str) } @@ -1079,6 +1088,36 @@ func ProjectSkillsDir(workingDir string) []string { return dirs } +// GlobalSubagentsDirs returns the default global directories for subagent definitions. +func GlobalSubagentsDirs() []string { + paths := []string{ + filepath.Join(home.Config(), appName, "agents"), + filepath.Join(home.Config(), "agents", "agents"), + filepath.Join(home.Dir(), ".agents", "agents"), + filepath.Join(home.Dir(), ".claude", "agents"), + } + if runtime.GOOS == "windows" { + appData := cmp.Or( + os.Getenv("LOCALAPPDATA"), + filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local"), + ) + paths = append( + paths, + filepath.Join(appData, appName, "agents"), + filepath.Join(appData, "agents", "agents"), + ) + } + return paths +} + +// ProjectSubagentsDir returns the default project directories for subagent definitions. +func ProjectSubagentsDir(workingDir string) []string { + return []string{ + filepath.Join(workingDir, ".agents/agents"), + filepath.Join(workingDir, ".claude/agents"), + } +} + func isAppleTerminal() bool { return os.Getenv("TERM_PROGRAM") == "Apple_Terminal" } // normalizeHookEvent maps user-provided event names to their canonical diff --git a/internal/config/store.go b/internal/config/store.go index 9dd8ddd959..283f943c3d 100644 --- a/internal/config/store.go +++ b/internal/config/store.go @@ -75,8 +75,21 @@ func (s *ConfigStore) Config() *Config { return s.config } -// WorkingDir returns the current working directory. +// WorkingDir returns the current working directory. When workingDir is empty +// (as in test stores created via NewTestStore) and exactly one loaded path was +// supplied that refers to an existing directory on disk, that path is returned +// as the working directory. This allows test helpers that pass a temp directory +// as a loaded path to get correct scope-detection results without requiring the +// full production initialization path. func (s *ConfigStore) WorkingDir() string { + if s.workingDir != "" { + return s.workingDir + } + if len(s.loadedPaths) == 1 { + if info, err := os.Stat(s.loadedPaths[0]); err == nil && info.IsDir() { + return s.loadedPaths[0] + } + } return s.workingDir } @@ -554,6 +567,15 @@ func NewTestStore(cfg *Config, loadedPaths ...string) *ConfigStore { } } +// NewTestStoreWithWorkingDir creates a ConfigStore for testing purposes with +// an explicit working directory set. This is required for scope-detection +// tests in the workspace package. +func NewTestStoreWithWorkingDir(cfg *Config, workingDir string, loadedPaths ...string) *ConfigStore { + s := NewTestStore(cfg, loadedPaths...) + s.workingDir = workingDir + return s +} + // ImportCopilot attempts to import a GitHub Copilot token from disk. func (s *ConfigStore) ImportCopilot() (*oauth.Token, bool) { if s.HasConfigField(ScopeGlobal, "providers.copilot.api_key") || s.HasConfigField(ScopeGlobal, "providers.copilot.oauth") { diff --git a/internal/config/subagents_paths_test.go b/internal/config/subagents_paths_test.go new file mode 100644 index 0000000000..f07ffd431c --- /dev/null +++ b/internal/config/subagents_paths_test.go @@ -0,0 +1,181 @@ +package config + +import ( + "encoding/json" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGlobalSubagentsDirs(t *testing.T) { + t.Parallel() + + dirs := GlobalSubagentsDirs() + + t.Run("returns non-empty slice", func(t *testing.T) { + t.Parallel() + require.NotEmpty(t, dirs) + }) + + t.Run("contains path ending with crush/agents or .config/crush/agents", func(t *testing.T) { + t.Parallel() + found := false + for _, d := range dirs { + if strings.HasSuffix(d, filepath.Join("crush", "agents")) || + strings.HasSuffix(d, filepath.Join(".config", "crush", "agents")) { + found = true + break + } + } + require.True(t, found, "expected a path ending with crush/agents or .config/crush/agents; got %v", dirs) + }) + + t.Run("contains path ending with .agents/agents", func(t *testing.T) { + t.Parallel() + found := false + for _, d := range dirs { + if strings.HasSuffix(d, filepath.Join(".agents", "agents")) { + found = true + break + } + } + require.True(t, found, "expected a path ending with .agents/agents; got %v", dirs) + }) + + t.Run("contains path ending with .claude/agents", func(t *testing.T) { + t.Parallel() + found := false + for _, d := range dirs { + if strings.HasSuffix(d, filepath.Join(".claude", "agents")) { + found = true + break + } + } + require.True(t, found, "expected a path ending with .claude/agents; got %v", dirs) + }) + + t.Run("all paths are absolute", func(t *testing.T) { + t.Parallel() + for _, d := range dirs { + require.True(t, filepath.IsAbs(d), "expected absolute path, got %q", d) + } + }) +} + +func TestProjectSubagentsDir(t *testing.T) { + t.Parallel() + + workingDir := "/some/project" + dirs := ProjectSubagentsDir(workingDir) + + t.Run("contains .agents/agents under workingDir", func(t *testing.T) { + t.Parallel() + require.Contains(t, dirs, filepath.Join(workingDir, ".agents", "agents")) + }) + + t.Run("contains .claude/agents under workingDir", func(t *testing.T) { + t.Parallel() + require.Contains(t, dirs, filepath.Join(workingDir, ".claude", "agents")) + }) + + t.Run("does not contain skills paths", func(t *testing.T) { + t.Parallel() + for _, d := range dirs { + require.False(t, strings.Contains(d, "/skills/") || strings.HasSuffix(d, "/skills"), + "subagents path must not contain a skills segment; got %q", d) + } + }) + + t.Run("all paths are absolute", func(t *testing.T) { + t.Parallel() + for _, d := range dirs { + require.True(t, filepath.IsAbs(d), "expected absolute path, got %q", d) + } + }) +} + +func TestSetDefaults_SubagentsPathsPopulated(t *testing.T) { + t.Parallel() + + cfg := &Config{} + cfg.setDefaults("/tmp/workdir", "") + + require.NotEmpty(t, cfg.Options.SubagentsPaths, "SubagentsPaths must be non-empty after setDefaults") + + t.Run("contains a path ending in .agents/agents", func(t *testing.T) { + t.Parallel() + found := false + for _, p := range cfg.Options.SubagentsPaths { + if strings.HasSuffix(p, filepath.Join(".agents", "agents")) { + found = true + break + } + } + require.True(t, found, "expected at least one path ending in .agents/agents; got %v", cfg.Options.SubagentsPaths) + }) + + t.Run("contains a path ending in .claude/agents", func(t *testing.T) { + t.Parallel() + found := false + for _, p := range cfg.Options.SubagentsPaths { + if strings.HasSuffix(p, filepath.Join(".claude", "agents")) { + found = true + break + } + } + require.True(t, found, "expected at least one path ending in .claude/agents; got %v", cfg.Options.SubagentsPaths) + }) + + t.Run("no cross-contamination with skills paths", func(t *testing.T) { + t.Parallel() + for _, p := range cfg.Options.SubagentsPaths { + require.False(t, strings.Contains(p, "/skills/") || strings.HasSuffix(p, "/skills"), + "SubagentsPaths must not contain a skills segment; got %q", p) + } + }) +} + +func TestSetDefaults_SubagentsPathsNotDuplicated(t *testing.T) { + t.Parallel() + + preexisting := "/custom/agents/path" + cfg := &Config{ + Options: &Options{ + SubagentsPaths: []string{preexisting}, + }, + } + workingDir := t.TempDir() + + cfg.setDefaults(workingDir, "") + cfg.setDefaults(workingDir, "") + + count := 0 + for _, p := range cfg.Options.SubagentsPaths { + if p == preexisting { + count++ + } + } + require.Equal(t, 1, count, + "pre-existing path %q appeared %d times after two setDefaults calls; expected exactly 1", + preexisting, count) +} + +func TestOptions_SubagentsPaths_JSONRoundtrip(t *testing.T) { + t.Parallel() + + original := Options{ + SubagentsPaths: []string{"/a", "/b"}, + DisabledSubagents: []string{"my-agent"}, + } + + data, err := json.Marshal(original) + require.NoError(t, err) + + var restored Options + require.NoError(t, json.Unmarshal(data, &restored)) + + require.Equal(t, original.SubagentsPaths, restored.SubagentsPaths) + require.Equal(t, original.DisabledSubagents, restored.DisabledSubagents) +} diff --git a/internal/subagents/color.go b/internal/subagents/color.go new file mode 100644 index 0000000000..4e21bd8e7f --- /dev/null +++ b/internal/subagents/color.go @@ -0,0 +1,43 @@ +package subagents + +import ( + "hash/fnv" + "slices" +) + +// Color name constants for the eight-color subagent palette. +const ( + ColorRed = "red" + ColorOrange = "orange" + ColorYellow = "yellow" + ColorGreen = "green" + ColorCyan = "cyan" + ColorBlue = "blue" + ColorPurple = "purple" + ColorPink = "pink" +) + +// colorPalette is the ordered list of all eight valid color names. +var colorPalette = [8]string{ + ColorRed, + ColorOrange, + ColorYellow, + ColorGreen, + ColorCyan, + ColorBlue, + ColorPurple, + ColorPink, +} + +// IsValidColor reports whether color is one of the eight defined palette names. +func IsValidColor(color string) bool { + return slices.Contains(colorPalette[:], color) +} + +// AutoColor deterministically maps name to one of the palette colors using +// FNV-32a hashing modulo the palette size. +func AutoColor(name string) string { + h := fnv.New32a() + _, _ = h.Write([]byte(name)) + return colorPalette[h.Sum32()%uint32(len(colorPalette))] +} diff --git a/internal/subagents/color_test.go b/internal/subagents/color_test.go new file mode 100644 index 0000000000..164bdac826 --- /dev/null +++ b/internal/subagents/color_test.go @@ -0,0 +1,122 @@ +package subagents + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestIsValidColor verifies that IsValidColor returns true for all eight +// defined color names and false for invalid values. +func TestIsValidColor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + color string + want bool + }{ + {name: "red_valid", color: "red", want: true}, + {name: "orange_valid", color: "orange", want: true}, + {name: "yellow_valid", color: "yellow", want: true}, + {name: "green_valid", color: "green", want: true}, + {name: "cyan_valid", color: "cyan", want: true}, + {name: "blue_valid", color: "blue", want: true}, + {name: "purple_valid", color: "purple", want: true}, + {name: "pink_valid", color: "pink", want: true}, + {name: "empty_invalid", color: "", want: false}, + {name: "ultra_invalid", color: "ultra", want: false}, + {name: "RED_wrong_case", color: "RED", want: false}, + {name: "lime_invalid", color: "lime", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.Equal(t, tt.want, IsValidColor(tt.color)) + }) + } +} + +// TestAutoColor_Deterministic verifies that calling AutoColor with the same +// name twice always returns the same color. +func TestAutoColor_Deterministic(t *testing.T) { + t.Parallel() + + names := []string{ + "my-agent", + "code-reviewer", + "data-analyst", + "test-runner", + "", + } + + for _, name := range names { + t.Run("name_"+name, func(t *testing.T) { + t.Parallel() + + first := AutoColor(name) + second := AutoColor(name) + require.Equal(t, first, second, "AutoColor(%q) must return the same value on repeated calls", name) + }) + } +} + +// TestAutoColor_ReturnsValidColor verifies that AutoColor always returns one +// of the eight defined color names across a broad set of inputs. +func TestAutoColor_ReturnsValidColor(t *testing.T) { + t.Parallel() + + names := []string{ + "alpha", + "beta", + "gamma", + "delta", + "epsilon", + "zeta", + "eta", + "theta", + "iota", + "kappa", + "lambda", + "mu", + "nu", + "xi", + "omicron", + "pi", + "rho", + "sigma", + "tau", + "upsilon", + } + + for _, name := range names { + t.Run("name_"+name, func(t *testing.T) { + t.Parallel() + + color := AutoColor(name) + require.True(t, IsValidColor(color), "AutoColor(%q) returned %q which is not a valid color", name, color) + }) + } +} + +// TestAutoColor_DistributionNotConstant verifies that AutoColor does not map +// all inputs to the same color (i.e. the hash is not degenerate). +func TestAutoColor_DistributionNotConstant(t *testing.T) { + t.Parallel() + + names := []string{ + "alpha", "beta", "gamma", "delta", "epsilon", + "zeta", "eta", "theta", "iota", "kappa", + "lambda", "mu", "nu", "xi", "omicron", + "pi", "rho", "sigma", "tau", "upsilon", + } + + seen := make(map[string]bool, len(names)) + for _, name := range names { + seen[AutoColor(name)] = true + } + + require.Greater(t, len(seen), 1, "AutoColor must map distinct names to at least 2 distinct colors") +} diff --git a/internal/subagents/effort.go b/internal/subagents/effort.go new file mode 100644 index 0000000000..b54d1c9a05 --- /dev/null +++ b/internal/subagents/effort.go @@ -0,0 +1,45 @@ +package subagents + +import ( + "charm.land/catwalk/pkg/catwalk" + + "github.com/charmbracelet/crush/internal/config" +) + +// Effort level constants — these are the catwalk ReasoningLevels values and +// pass through directly to config.SelectedModel.ReasoningEffort. +const ( + EffortNone = "none" + EffortMinimal = "minimal" + EffortLow = "low" + EffortMedium = "medium" + EffortHigh = "high" + EffortXHigh = "xhigh" + EffortMax = "max" +) + +// EffortIgnored reports whether a non-empty effort would be silently dropped +// because the model cannot reason. Callers use it to warn on misconfiguration; +// ApplyEffortToModel no-ops in the same case. +func EffortIgnored(effort string, catwalkModel catwalk.Model) bool { + return effort != "" && !catwalkModel.CanReason +} + +// ApplyEffortToModel applies the given effort level to a copy of selectedModel +// and returns the modified copy. The catwalkModel is used to determine whether +// the model supports reasoning. +// +// Rules: +// - Empty effort is a no-op: the copy is returned unchanged. +// - Models where CanReason is false are never modified. +// - All other models: ReasoningEffort is set directly to the effort string. +// The coordinator's shouldSetEffort check (slices.Contains(ReasoningLevels, +// ReasoningEffort)) handles unsupported levels gracefully at dispatch time. +func ApplyEffortToModel(effort string, selectedModel config.SelectedModel, catwalkModel catwalk.Model) config.SelectedModel { + if effort == "" || !catwalkModel.CanReason { + return selectedModel + } + result := selectedModel + result.ReasoningEffort = effort + return result +} diff --git a/internal/subagents/effort_test.go b/internal/subagents/effort_test.go new file mode 100644 index 0000000000..c706e04aa0 --- /dev/null +++ b/internal/subagents/effort_test.go @@ -0,0 +1,435 @@ +package subagents + +import ( + "testing" + + "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" + "github.com/stretchr/testify/require" +) + +// TestParseContent_EffortField verifies that the effort field round-trips +// through YAML frontmatter parsing for all defined values plus absent/empty. +func TestParseContent_EffortField(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + wantEffort string + }{ + { + name: "effort_none", + content: `--- +name: my-agent +description: A test agent. +effort: none +--- +`, + wantEffort: "none", + }, + { + name: "effort_minimal", + content: `--- +name: my-agent +description: A test agent. +effort: minimal +--- +`, + wantEffort: "minimal", + }, + { + name: "effort_low", + content: `--- +name: my-agent +description: A test agent. +effort: low +--- +`, + wantEffort: "low", + }, + { + name: "effort_medium", + content: `--- +name: my-agent +description: A test agent. +effort: medium +--- +`, + wantEffort: "medium", + }, + { + name: "effort_high", + content: `--- +name: my-agent +description: A test agent. +effort: high +--- +`, + wantEffort: "high", + }, + { + name: "effort_xhigh", + content: `--- +name: my-agent +description: A test agent. +effort: xhigh +--- +`, + wantEffort: "xhigh", + }, + { + name: "effort_max", + content: `--- +name: my-agent +description: A test agent. +effort: max +--- +`, + wantEffort: "max", + }, + { + name: "effort_absent_is_empty", + content: `--- +name: my-agent +description: A test agent. +--- +`, + wantEffort: "", + }, + { + name: "effort_explicit_empty_string", + content: `--- +name: my-agent +description: A test agent. +effort: "" +--- +`, + wantEffort: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + agent, err := ParseContent([]byte(tt.content)) + require.NoError(t, err) + require.Equal(t, tt.wantEffort, agent.Effort) + }) + } +} + +// TestValidate_EffortField verifies that Validate accepts all seven defined +// effort constants and empty, and rejects everything else. +func TestValidate_EffortField(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + effort string + wantErr bool + errMsg string + }{ + {name: "empty_accepted", effort: "", wantErr: false}, + {name: "none_accepted", effort: "none", wantErr: false}, + {name: "minimal_accepted", effort: "minimal", wantErr: false}, + {name: "low_accepted", effort: "low", wantErr: false}, + {name: "medium_accepted", effort: "medium", wantErr: false}, + {name: "high_accepted", effort: "high", wantErr: false}, + {name: "xhigh_accepted", effort: "xhigh", wantErr: false}, + {name: "max_accepted", effort: "max", wantErr: false}, + {name: "ultra_rejected", effort: "ultra", wantErr: true, errMsg: "effort"}, + {name: "turbo_rejected", effort: "turbo", wantErr: true, errMsg: "effort"}, + {name: "HIGH_rejected_case_sensitive", effort: "HIGH", wantErr: true, errMsg: "effort"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + s := Subagent{ + Name: "test-agent", + Description: "Does something.", + Effort: tt.effort, + } + err := s.Validate() + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestApplyEffortToModel_OpenAI verifies that effort values pass through +// directly as ReasoningEffort for an OpenAI-family model. Think is never set. +func TestApplyEffortToModel_OpenAI(t *testing.T) { + t.Parallel() + + m := catwalk.Model{ + ID: "o4-mini", + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, + } + + tests := []struct { + effort string + wantReasoning string + }{ + {"none", "none"}, + {"minimal", "minimal"}, + {"low", "low"}, + {"medium", "medium"}, + {"high", "high"}, + {"xhigh", "xhigh"}, + {"max", "max"}, + } + + for _, tt := range tests { + t.Run("effort_"+tt.effort, func(t *testing.T) { + t.Parallel() + + base := config.SelectedModel{ + Model: "o4-mini", + Provider: "openai", + } + result := ApplyEffortToModel(tt.effort, base, m) + require.Equal(t, tt.wantReasoning, result.ReasoningEffort) + require.False(t, result.Think, "Think must never be set by ApplyEffortToModel") + }) + } +} + +// TestApplyEffortToModel_Anthropic verifies that effort values pass through +// directly as ReasoningEffort for Anthropic models — Think is never set. +func TestApplyEffortToModel_Anthropic(t *testing.T) { + t.Parallel() + + m := catwalk.Model{ + ID: "claude-opus-4-7", + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high", "xhigh", "max"}, + } + + tests := []struct { + effort string + wantReasoning string + }{ + {"low", "low"}, + {"medium", "medium"}, + {"high", "high"}, + {"xhigh", "xhigh"}, + {"max", "max"}, + } + + for _, tt := range tests { + t.Run("effort_"+tt.effort, func(t *testing.T) { + t.Parallel() + + base := config.SelectedModel{ + Model: "claude-opus-4-7", + Provider: "anthropic", + } + result := ApplyEffortToModel(tt.effort, base, m) + require.Equal(t, tt.wantReasoning, result.ReasoningEffort) + require.False(t, result.Think, "Think must never be set by ApplyEffortToModel") + }) + } +} + +// TestApplyEffortToModel_EmptyEffort_NoOp verifies that an empty effort string +// returns the model unchanged. +func TestApplyEffortToModel_EmptyEffort_NoOp(t *testing.T) { + t.Parallel() + + m := catwalk.Model{ + ID: "o4-mini", + CanReason: true, + } + base := config.SelectedModel{ + Model: "o4-mini", + Provider: "openai", + } + + result := ApplyEffortToModel("", base, m) + require.Empty(t, result.ReasoningEffort, "empty effort must not set ReasoningEffort") + require.False(t, result.Think, "empty effort must not set Think") +} + +// TestApplyEffortToModel_NonReasoningModel verifies that effort has no effect +// on models that do not support reasoning (CanReason == false). +func TestApplyEffortToModel_NonReasoningModel(t *testing.T) { + t.Parallel() + + m := catwalk.Model{ + ID: "gpt-4o", + CanReason: false, + } + base := config.SelectedModel{ + Model: "gpt-4o", + Provider: "openai", + } + + for _, effort := range []string{"none", "minimal", "low", "medium", "high", "xhigh", "max"} { + t.Run("effort_"+effort, func(t *testing.T) { + t.Parallel() + + result := ApplyEffortToModel(effort, base, m) + require.Empty(t, result.ReasoningEffort, "non-reasoning model must not have ReasoningEffort set") + require.False(t, result.Think, "non-reasoning model must not have Think set") + }) + } +} + +// TestApplyEffortToModel_PreservesOtherFields verifies that ApplyEffortToModel +// does not mutate fields unrelated to effort on the SelectedModel. +func TestApplyEffortToModel_PreservesOtherFields(t *testing.T) { + t.Parallel() + + m := catwalk.Model{ + ID: "o4-mini", + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, + } + base := config.SelectedModel{ + Model: "o4-mini", + Provider: "openai", + MaxTokens: 8192, + } + + result := ApplyEffortToModel("high", base, m) + require.Equal(t, base.Model, result.Model) + require.Equal(t, base.Provider, result.Provider) + require.Equal(t, base.MaxTokens, result.MaxTokens) + require.Equal(t, "high", result.ReasoningEffort) +} + +// TestApplyEffortToModel_XHighAndMaxPassThrough verifies that xhigh and max +// are set verbatim as ReasoningEffort without any clamping or mapping. +func TestApplyEffortToModel_XHighAndMaxPassThrough(t *testing.T) { + t.Parallel() + + m := catwalk.Model{ + ID: "o4-mini", + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, + } + base := config.SelectedModel{ + Model: "o4-mini", + Provider: "openai", + } + + t.Run("xhigh", func(t *testing.T) { + t.Parallel() + + result := ApplyEffortToModel("xhigh", base, m) + require.Equal(t, "xhigh", result.ReasoningEffort) + }) + + t.Run("max", func(t *testing.T) { + t.Parallel() + + result := ApplyEffortToModel("max", base, m) + require.Equal(t, "max", result.ReasoningEffort) + }) +} + +// TestApplyEffortToModel_EmptyReasoningLevels verifies that when a model has +// CanReason=true but no ReasoningLevels list, effort still passes through as +// ReasoningEffort. The coordinator's shouldSetEffort check handles filtering at +// dispatch time; ApplyEffortToModel itself does not clamp. +func TestApplyEffortToModel_EmptyReasoningLevels(t *testing.T) { + t.Parallel() + + m := catwalk.Model{ + ID: "some-reasoning-model", + CanReason: true, + ReasoningLevels: nil, + } + base := config.SelectedModel{ + Model: "some-reasoning-model", + Provider: "openai", + } + + result := ApplyEffortToModel("high", base, m) + require.Equal(t, "high", result.ReasoningEffort) + require.False(t, result.Think) +} + +// TestDispatchAppliesEffort_EndToEnd verifies the end-to-end path: when a +// Subagent has an Effort field set, combining ToConfigAgent and +// ApplyEffortToModel produces a model with ReasoningEffort set and Think unset. +func TestDispatchAppliesEffort_EndToEnd(t *testing.T) { + t.Parallel() + + sa := &Subagent{ + Name: "sharp-agent", + Description: "An effort-aware subagent.", + Effort: "high", + } + + base := config.Agent{ + AllowedTools: []string{"bash", "grep"}, + Model: config.SelectedModelTypeLarge, + } + agentCfg := sa.ToConfigAgent(base) + require.Equal(t, "sharp-agent", agentCfg.ID) + + resolvedModel := config.SelectedModel{ + Model: "o4-mini", + Provider: "openai", + } + catwalkModel := catwalk.Model{ + ID: "o4-mini", + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, + } + + applied := ApplyEffortToModel(sa.Effort, resolvedModel, catwalkModel) + require.Equal(t, "high", applied.ReasoningEffort) + require.False(t, applied.Think) +} + +// TestDispatchAppliesEffort_Anthropic_EndToEnd verifies that the Anthropic +// path also sets ReasoningEffort (not Think) for effort values in the model's +// supported range. +func TestDispatchAppliesEffort_Anthropic_EndToEnd(t *testing.T) { + t.Parallel() + + for _, effort := range []string{"medium", "high", "xhigh", "max"} { + t.Run("effort_"+effort, func(t *testing.T) { + t.Parallel() + + resolvedModel := config.SelectedModel{ + Model: "claude-opus-4-7", + Provider: "anthropic", + } + catwalkModel := catwalk.Model{ + ID: "claude-opus-4-7", + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high", "xhigh", "max"}, + } + + applied := ApplyEffortToModel(effort, resolvedModel, catwalkModel) + require.Equal(t, effort, applied.ReasoningEffort, "effort=%q must set ReasoningEffort for Anthropic models", effort) + require.False(t, applied.Think, "Think must never be set by ApplyEffortToModel") + }) + } +} + +// TestEffortIgnored verifies the capability check used to warn on misconfig: +// a non-empty effort on a non-reasoning model is "ignored". +func TestEffortIgnored(t *testing.T) { + t.Parallel() + + reasoning := catwalk.Model{ID: "r", CanReason: true} + plain := catwalk.Model{ID: "p", CanReason: false} + + require.True(t, EffortIgnored("high", plain), "effort on non-reasoning model is ignored") + require.False(t, EffortIgnored("high", reasoning), "effort on reasoning model is honored") + require.False(t, EffortIgnored("", plain), "empty effort is never a misconfig") + require.False(t, EffortIgnored("", reasoning)) +} diff --git a/internal/subagents/manager.go b/internal/subagents/manager.go new file mode 100644 index 0000000000..70557dc9d2 --- /dev/null +++ b/internal/subagents/manager.go @@ -0,0 +1,165 @@ +package subagents + +import ( + "context" + "slices" + "strings" + "sync" + + "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/crush/internal/pubsub" +) + +// Manager owns per-workspace subagent discovery state: the latest discovery +// snapshot, the full subagent metadata, and a pubsub broker for change events. +// There is exactly one Manager per workspace. +type Manager struct { + mu sync.RWMutex + allSubagents []*Subagent + activeSubagents []*Subagent + states []*SubagentState + + broker *pubsub.Broker[Event] +} + +// ManagerOption configures a Manager at construction time. +type ManagerOption func(*Manager) + +// NewManager constructs a workspace-scoped Manager with the given +// pre-computed discovery results. The slices are stored as-is; callers +// should not mutate them afterwards. +func NewManager(all, active []*Subagent, states []*SubagentState, opts ...ManagerOption) *Manager { + m := &Manager{ + allSubagents: all, + activeSubagents: active, + states: states, + broker: pubsub.NewBroker[Event](), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// AllSubagents returns a copy of the deduplicated list of all discovered +// subagents. The returned slice is safe for the caller to mutate. +func (m *Manager) AllSubagents() []*Subagent { + m.mu.RLock() + defer m.mu.RUnlock() + return cloneSubagents(m.allSubagents) +} + +// ActiveSubagents returns a copy of the post-filter list of active subagents +// (after removing disabled entries). The returned slice is safe for the caller +// to mutate. +func (m *Manager) ActiveSubagents() []*Subagent { + m.mu.RLock() + defer m.mu.RUnlock() + return cloneSubagents(m.activeSubagents) +} + +// States returns a clone of the latest discovery state snapshot. +func (m *Manager) States() []*SubagentState { + m.mu.RLock() + defer m.mu.RUnlock() + return cloneStates(m.states) +} + +// SetLatestStates updates the manager's cached discovery snapshot. +func (m *Manager) SetLatestStates(states []*SubagentState) { + m.mu.Lock() + m.states = cloneStates(states) + m.mu.Unlock() +} + +// PublishStates updates the manager's cached snapshot and publishes a +// discovery event to subscribers. +func (m *Manager) PublishStates(states []*SubagentState) { + m.mu.Lock() + m.states = cloneStates(states) + m.mu.Unlock() + m.broker.Publish(pubsub.UpdatedEvent, Event{States: cloneStates(states)}) +} + +// SubscribeEvents returns a channel of discovery events for the +// manager's workspace. +func (m *Manager) SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] { + if m == nil || m.broker == nil { + ch := make(chan pubsub.Event[Event]) + close(ch) + return ch + } + return m.broker.Subscribe(ctx) +} + +// Reload atomically replaces the manager's allSubagents, activeSubagents, and +// states with the provided slices, then publishes an UpdatedEvent to +// subscribers. It is a no-op when m is nil. +func (m *Manager) Reload(all, active []*Subagent, states []*SubagentState) { + if m == nil { + return + } + m.mu.Lock() + m.allSubagents = cloneSubagents(all) + m.activeSubagents = cloneSubagents(active) + m.states = cloneStates(states) + m.mu.Unlock() + m.broker.Publish(pubsub.UpdatedEvent, Event{States: cloneStates(states)}) +} + +// Shutdown releases broker resources. +func (m *Manager) Shutdown() { + if m.broker != nil { + m.broker.Shutdown() + } +} + +// DiscoveryConfig contains the inputs DiscoverFromConfig needs. +type DiscoveryConfig struct { + SubagentsPaths []string + DisabledSubagents []string + // Resolver expands $VAR-style references in paths. May be nil. + Resolver func(string) (string, error) + // IsKnownModel validates that a model id (anything other than the + // "large"/"small" aliases) resolves to a real provider model. May be nil + // during discovery in contexts where the config is not yet loaded; in that + // case model-id validation is skipped. + IsKnownModel func(provider, model string) bool +} + +// ResolvePaths expands home-directory and $VAR references in SubagentsPaths. +func (c DiscoveryConfig) ResolvePaths() []string { + if len(c.SubagentsPaths) == 0 { + return nil + } + out := make([]string, 0, len(c.SubagentsPaths)) + for _, pth := range c.SubagentsPaths { + expanded := home.Long(pth) + if strings.HasPrefix(expanded, "$") && c.Resolver != nil { + if resolved, err := c.Resolver(expanded); err == nil { + expanded = resolved + } + } + out = append(out, expanded) + } + return out +} + +// DiscoverFromConfig walks every path in cfg.SubagentsPaths (after home / env +// expansion), then dedups and filters by cfg.DisabledSubagents. It returns the +// three slices the rest of the system needs: +// +// - all: deduplicated, pre-filter (includes disabled). +// - active: post-filter (DisabledSubagents removed). +// - states: per-file discovery outcome for diagnostics/UI. +func DiscoverFromConfig(cfg DiscoveryConfig) (all, active []*Subagent, states []*SubagentState) { + userPaths := cfg.ResolvePaths() + discovered, allStates := DiscoverWithStates(userPaths, cfg.IsKnownModel) + all = Deduplicate(discovered) + active = Filter(all, cfg.DisabledSubagents) + allStates = DeduplicateStates(allStates) + slices.SortStableFunc(allStates, func(a, b *SubagentState) int { + return strings.Compare(strings.ToLower(a.Path), strings.ToLower(b.Path)) + }) + return all, active, allStates +} diff --git a/internal/subagents/manager_test.go b/internal/subagents/manager_test.go new file mode 100644 index 0000000000..333ebdf342 --- /dev/null +++ b/internal/subagents/manager_test.go @@ -0,0 +1,385 @@ +package subagents + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + "time" + + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/stretchr/testify/require" +) + +func TestManager_AllSubagents(t *testing.T) { + t.Parallel() + + mgr := NewManager([]*Subagent{{Name: "a"}}, nil, nil) + t.Cleanup(mgr.Shutdown) + + got := mgr.AllSubagents() + require.Len(t, got, 1) + require.Equal(t, "a", got[0].Name) +} + +func TestManager_ActiveSubagents(t *testing.T) { + t.Parallel() + + mgr := NewManager(nil, []*Subagent{{Name: "b"}}, nil) + t.Cleanup(mgr.Shutdown) + + got := mgr.ActiveSubagents() + require.Len(t, got, 1) + require.Equal(t, "b", got[0].Name) +} + +func TestManager_AllSubagents_ReturnsClone(t *testing.T) { + t.Parallel() + + original := &Subagent{Name: "a"} + mgr := NewManager([]*Subagent{original}, nil, nil) + t.Cleanup(mgr.Shutdown) + + got := mgr.AllSubagents() + require.Len(t, got, 1) + // Mutate returned slice; subsequent read must see original content. + got[0] = &Subagent{Name: "mutated"} + got = append(got, &Subagent{Name: "appended"}) + + after := mgr.AllSubagents() + require.Len(t, after, 1, "mutating returned slice must not change manager state") + require.Equal(t, "a", after[0].Name) +} + +func TestManager_ActiveSubagents_ReturnsClone(t *testing.T) { + t.Parallel() + + mgr := NewManager(nil, []*Subagent{{Name: "b"}}, nil) + t.Cleanup(mgr.Shutdown) + + got := mgr.ActiveSubagents() + got[0] = &Subagent{Name: "mutated"} + got = append(got, &Subagent{Name: "extra"}) + + after := mgr.ActiveSubagents() + require.Len(t, after, 1) + require.Equal(t, "b", after[0].Name) +} + +func TestManager_States(t *testing.T) { + t.Parallel() + + mgr := NewManager(nil, nil, []*SubagentState{{Name: "x"}}) + t.Cleanup(mgr.Shutdown) + + got := mgr.States() + require.Len(t, got, 1) + require.Equal(t, "x", got[0].Name) +} + +func TestManager_SetLatestStates_UpdatesCacheWithoutPublishing(t *testing.T) { + t.Parallel() + + mgr := NewManager(nil, nil, []*SubagentState{{Name: "old"}}) + t.Cleanup(mgr.Shutdown) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + ch := mgr.SubscribeEvents(ctx) + + mgr.SetLatestStates([]*SubagentState{{Name: "new"}}) + + got := mgr.States() + require.Len(t, got, 1) + require.Equal(t, "new", got[0].Name) + + select { + case ev := <-ch: + t.Fatalf("SetLatestStates must not publish events, got %+v", ev) + case <-time.After(50 * time.Millisecond): + // expected: no event delivered + } +} + +func TestManager_Shutdown_IsIdempotent(t *testing.T) { + t.Parallel() + + mgr := NewManager(nil, nil, nil) + require.NotPanics(t, func() { + mgr.Shutdown() + mgr.Shutdown() + }) +} + +func TestManager_PublishStatesUpdatesCache(t *testing.T) { + t.Parallel() + + mgr := NewManager(nil, nil, []*SubagentState{{Name: "old"}}) + t.Cleanup(mgr.Shutdown) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + ch := mgr.SubscribeEvents(ctx) + + mgr.PublishStates([]*SubagentState{{Name: "new"}}) + + got := mgr.States() + require.Len(t, got, 1) + require.Equal(t, "new", got[0].Name) + + select { + case ev := <-ch: + require.Len(t, ev.Payload.States, 1) + require.Equal(t, "new", ev.Payload.States[0].Name) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for published states event") + } +} + +func TestManager_ConcurrentWorkspacesAreIsolated(t *testing.T) { + t.Parallel() + + mgrA := NewManager(nil, nil, nil) + mgrB := NewManager(nil, nil, nil) + t.Cleanup(mgrA.Shutdown) + t.Cleanup(mgrB.Shutdown) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + chA := mgrA.SubscribeEvents(ctx) + chB := mgrB.SubscribeEvents(ctx) + + go mgrA.PublishStates([]*SubagentState{{Name: "from-a"}}) + + select { + case ev := <-chA: + require.Equal(t, "from-a", ev.Payload.States[0].Name) + case <-time.After(2 * time.Second): + t.Fatal("workspace A never received its own event") + } + + select { + case ev := <-chB: + t.Fatalf("workspace B received workspace A's event: %v", ev) + case <-time.After(100 * time.Millisecond): + // Expected — B's stream is isolated. + } +} + +// Compile-time assertion: SubscribeEvents must return the correct channel type. +var _ <-chan pubsub.Event[Event] = (*Manager)(nil).SubscribeEvents(context.Background()) + +func TestDiscoverFromConfig(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "my-agent.md"), + []byte("---\nname: my-agent\ndescription: Does the thing.\n---\n\nYou are a specialist agent.\n"), + 0o644, + )) + + all, active, states := DiscoverFromConfig(DiscoveryConfig{ + SubagentsPaths: []string{tmp}, + }) + + require.NotEmpty(t, all) + require.NotEmpty(t, active) + + var found *Subagent + for _, a := range all { + if a.Name == "my-agent" { + found = a + break + } + } + require.NotNil(t, found, "my-agent must appear in all") + require.NotEmpty(t, found.Body, "DiscoverFromConfig must return Subagent.Body") + + inActive := false + for _, a := range active { + if a.Name == "my-agent" { + inActive = true + break + } + } + require.True(t, inActive, "my-agent must appear in active") + + foundState := false + for _, s := range states { + if s.Name == "my-agent" { + foundState = true + require.Equal(t, StateNormal, s.State) + } + } + require.True(t, foundState, "states must include my-agent with StateNormal") +} + +func TestDiscoverFromConfig_RejectsUnknownModelViaResolver(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "good-model.md"), + []byte("---\nname: good-model\ndescription: ok\nmodel: gpt-4o\n---\n\nBody.\n"), + 0o644, + )) + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "bad-model.md"), + []byte("---\nname: bad-model\ndescription: bad\nmodel: imaginary-99\n---\n\nBody.\n"), + 0o644, + )) + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "alias.md"), + []byte("---\nname: alias-model\ndescription: alias\nmodel: large\n---\n\nBody.\n"), + 0o644, + )) + + knownModels := map[string]bool{"gpt-4o": true} + all, active, states := DiscoverFromConfig(DiscoveryConfig{ + SubagentsPaths: []string{tmp}, + IsKnownModel: func(provider, id string) bool { return knownModels[id] }, + }) + + activeNames := make(map[string]bool, len(active)) + for _, a := range active { + activeNames[a.Name] = true + } + require.True(t, activeNames["good-model"], "good-model must be active") + require.True(t, activeNames["alias-model"], "alias-model (large) must be active") + require.False(t, activeNames["bad-model"], "bad-model must be dropped on validation failure") + + allNames := make(map[string]bool, len(all)) + for _, a := range all { + allNames[a.Name] = true + } + require.False(t, allNames["bad-model"], "bad-model must not appear in all (validation failed)") + + var badState *SubagentState + for _, s := range states { + if s.Name == "bad-model" { + badState = s + } + } + require.NotNil(t, badState, "states must include bad-model entry") + require.Equal(t, StateError, badState.State) + require.ErrorContains(t, badState.Err, "model") +} + +func TestDiscoverFromConfig_DisabledFiltered(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "that-agent.md"), + []byte("---\nname: that-agent\ndescription: Should be disabled.\n---\n\nBody.\n"), + 0o644, + )) + + all, active, states := DiscoverFromConfig(DiscoveryConfig{ + SubagentsPaths: []string{tmp}, + DisabledSubagents: []string{"that-agent"}, + }) + + hasInAll := false + for _, a := range all { + if a.Name == "that-agent" { + hasInAll = true + } + } + require.True(t, hasInAll, "DisabledSubagents must not be removed from all") + + for _, a := range active { + require.NotEqual(t, "that-agent", a.Name, "DisabledSubagents must be removed from active") + } + + hasInStates := false + for _, s := range states { + if s.Name == "that-agent" { + hasInStates = true + } + } + require.True(t, hasInStates, "states must still include disabled agent") +} + +func TestDiscoverFromConfig_Resolver(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "env-agent.md"), + []byte("---\nname: env-agent\ndescription: Env-resolved agent.\n---\n\nBody.\n"), + 0o644, + )) + + all, _, _ := DiscoverFromConfig(DiscoveryConfig{ + SubagentsPaths: []string{"$CUSTOM_AGENTS_DIR"}, + Resolver: func(s string) (string, error) { + if s == "$CUSTOM_AGENTS_DIR" { + return tmp, nil + } + return s, errors.New("unknown variable") + }, + }) + + found := false + for _, a := range all { + if a.Name == "env-agent" { + found = true + } + } + require.True(t, found, "DiscoverFromConfig must expand $VAR via Resolver") +} + +func TestDiscoverFromConfig_EmptyPaths(t *testing.T) { + t.Parallel() + + all, active, _ := DiscoverFromConfig(DiscoveryConfig{}) + + require.Empty(t, all) + require.Empty(t, active) +} + +func TestManager_Reload(t *testing.T) { + t.Parallel() + + initial := []*Subagent{{Name: "old-all"}} + initialActive := []*Subagent{{Name: "old-active"}} + mgr := NewManager(initial, initialActive, nil) + t.Cleanup(mgr.Shutdown) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + ch := mgr.SubscribeEvents(ctx) + + newAll := []*Subagent{{Name: "new-all-a"}, {Name: "new-all-b"}} + newActive := []*Subagent{{Name: "new-active"}} + newStates := []*SubagentState{{Name: "new-state"}} + + mgr.Reload(newAll, newActive, newStates) + + // AllSubagents reflects the new slice. + gotAll := mgr.AllSubagents() + require.Len(t, gotAll, 2) + require.Equal(t, "new-all-a", gotAll[0].Name) + require.Equal(t, "new-all-b", gotAll[1].Name) + + // ActiveSubagents reflects the new slice. + gotActive := mgr.ActiveSubagents() + require.Len(t, gotActive, 1) + require.Equal(t, "new-active", gotActive[0].Name) + + // States reflects the new slice. + gotStates := mgr.States() + require.Len(t, gotStates, 1) + require.Equal(t, "new-state", gotStates[0].Name) + + // An event must be published to subscribers. + select { + case ev := <-ch: + require.Equal(t, pubsub.UpdatedEvent, ev.Type) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for Reload event") + } +} diff --git a/internal/subagents/runtime.go b/internal/subagents/runtime.go new file mode 100644 index 0000000000..e0a85f7398 --- /dev/null +++ b/internal/subagents/runtime.go @@ -0,0 +1,184 @@ +package subagents + +import ( + "context" + "sync" + "time" + + "github.com/charmbracelet/crush/internal/pubsub" +) + +// RunningEntry holds the live state of a single running sub-agent. +type RunningEntry struct { + ChildSessionID string + ParentSessionID string + Name string + Color string + Model string + Status string + StartedAt time.Time +} + +// Live and terminal statuses for sub-agent runs. +const ( + StatusRunning = "running" + StatusRetrying = "retrying" + StatusCompleted = "completed" + StatusCancelled = "cancelled" + StatusFailed = "failed" +) + +// RuntimeEvent is published whenever the set of running sub-agents changes. +// Finished is non-nil when the event reflects a sub-agent that just finished, +// carrying its final entry (including terminal Status) so the UI can react. +type RuntimeEvent struct { + ParentSessionID string + Entries []RunningEntry + Finished *RunningEntry +} + +// Runtime tracks which sub-agents are currently running across all sessions. +// There is exactly one Runtime per workspace; it is safe for concurrent use. +type Runtime struct { + mu sync.RWMutex + entries map[string]RunningEntry // keyed by childSessionID + broker *pubsub.Broker[RuntimeEvent] +} + +// NewRuntime constructs an empty Runtime ready for use. +func NewRuntime() *Runtime { + return &Runtime{ + entries: make(map[string]RunningEntry), + broker: pubsub.NewBroker[RuntimeEvent](), + } +} + +// Register records a new running sub-agent and publishes a RuntimeEvent. +// It is a no-op when r is nil. +func (r *Runtime) Register(parentSessionID, childSessionID, name, color, model string) RunningEntry { + if r == nil { + return RunningEntry{} + } + entry := RunningEntry{ + ChildSessionID: childSessionID, + ParentSessionID: parentSessionID, + Name: name, + Color: color, + Model: model, + Status: StatusRunning, + StartedAt: time.Now(), + } + r.mu.Lock() + r.entries[childSessionID] = entry + r.mu.Unlock() + + r.publish(parentSessionID) + return entry +} + +// Finish removes a running sub-agent entry with a terminal status and publishes +// a RuntimeEvent whose Finished field carries the removed entry. Use one of the +// Status* constants for finalStatus. It is a no-op when r is nil. +func (r *Runtime) Finish(childSessionID, finalStatus string) { + if r == nil { + return + } + r.mu.Lock() + entry, ok := r.entries[childSessionID] + if ok { + entry.Status = finalStatus + delete(r.entries, childSessionID) + } + r.mu.Unlock() + + if !ok { + return + } + + r.broker.Publish(pubsub.UpdatedEvent, RuntimeEvent{ + ParentSessionID: entry.ParentSessionID, + Entries: r.entriesFor(entry.ParentSessionID), + Finished: &entry, + }) +} + +// SetStatus updates the Status field of a running sub-agent and publishes a +// RuntimeEvent. It is a no-op when r is nil or the entry is not found. +func (r *Runtime) SetStatus(childSessionID, status string) { + if r == nil { + return + } + r.mu.Lock() + entry, ok := r.entries[childSessionID] + if ok { + entry.Status = status + r.entries[childSessionID] = entry + } + r.mu.Unlock() + + if ok { + r.publish(entry.ParentSessionID) + } +} + +// List returns a snapshot of all running entries belonging to parentSessionID. +// The returned slice is a copy; mutating it does not affect internal state. +// Returns nil when r is nil or no entries match. +func (r *Runtime) List(parentSessionID string) []RunningEntry { + if r == nil { + return nil + } + r.mu.RLock() + defer r.mu.RUnlock() + + var out []RunningEntry + for _, e := range r.entries { + if e.ParentSessionID == parentSessionID { + out = append(out, e) + } + } + return out +} + +// Subscribe returns a channel that receives RuntimeEvents whenever the set of +// running sub-agents changes. The channel is closed when ctx is cancelled or +// Shutdown is called. Returns a closed channel when r is nil. +func (r *Runtime) Subscribe(ctx context.Context) <-chan pubsub.Event[RuntimeEvent] { + if r == nil { + ch := make(chan pubsub.Event[RuntimeEvent]) + close(ch) + return ch + } + return r.broker.Subscribe(ctx) +} + +// Shutdown releases broker resources. It is a no-op when r is nil. +func (r *Runtime) Shutdown() { + if r == nil { + return + } + r.broker.Shutdown() +} + +// publish gathers all entries for parentSessionID and sends a RuntimeEvent. +// Called with no locks held. +func (r *Runtime) publish(parentSessionID string) { + r.broker.Publish(pubsub.UpdatedEvent, RuntimeEvent{ + ParentSessionID: parentSessionID, + Entries: r.entriesFor(parentSessionID), + }) +} + +// entriesFor returns a snapshot of all entries belonging to parentSessionID. +// Acquires the read lock; callers must hold no locks. +func (r *Runtime) entriesFor(parentSessionID string) []RunningEntry { + r.mu.RLock() + defer r.mu.RUnlock() + var entries []RunningEntry + for _, e := range r.entries { + if e.ParentSessionID == parentSessionID { + entries = append(entries, e) + } + } + return entries +} diff --git a/internal/subagents/runtime_test.go b/internal/subagents/runtime_test.go new file mode 100644 index 0000000000..494eac7f26 --- /dev/null +++ b/internal/subagents/runtime_test.go @@ -0,0 +1,268 @@ +package subagents + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/stretchr/testify/require" +) + +func TestRuntime_Register(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + entry := rt.Register("parent-1", "child-1", "my-agent", "blue", "") + + require.Equal(t, "parent-1", entry.ParentSessionID) + require.Equal(t, "child-1", entry.ChildSessionID) + require.Equal(t, "my-agent", entry.Name) + require.Equal(t, "blue", entry.Color) + require.Equal(t, StatusRunning, entry.Status) + require.False(t, entry.StartedAt.IsZero(), "StartedAt must be set") + + entries := rt.List("parent-1") + require.Len(t, entries, 1) + require.Equal(t, entry, entries[0]) +} + +func TestRuntime_SetStatus(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + rt.Register("parent-1", "child-1", "my-agent", "green", "") + rt.SetStatus("child-1", "queued") + + entries := rt.List("parent-1") + require.Len(t, entries, 1) + require.Equal(t, "queued", entries[0].Status) +} + +func TestRuntime_List_IsolatedByParent(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + rt.Register("parent-A", "child-A", "agent-a", "cyan", "") + rt.Register("parent-B", "child-B", "agent-b", "magenta", "") + + entriesA := rt.List("parent-A") + require.Len(t, entriesA, 1) + require.Equal(t, "child-A", entriesA[0].ChildSessionID) + + entriesB := rt.List("parent-B") + require.Len(t, entriesB, 1) + require.Equal(t, "child-B", entriesB[0].ChildSessionID) +} + +func TestRuntime_List_ReturnsCopy(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + rt.Register("parent-1", "child-1", "my-agent", "yellow", "") + + first := rt.List("parent-1") + require.Len(t, first, 1) + + // Mutate the returned slice. + first[0] = RunningEntry{ChildSessionID: "mutated"} + first = append(first, RunningEntry{ChildSessionID: "extra"}) + + // Internal state must be unaffected. + second := rt.List("parent-1") + require.Len(t, second, 1) + require.Equal(t, "child-1", second[0].ChildSessionID) +} + +func TestRuntime_Subscribe_ReceivesRegisterEvent(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + ch := rt.Subscribe(ctx) + + rt.Register("parent-1", "child-1", "my-agent", "blue", "") + + select { + case ev := <-ch: + require.Equal(t, "parent-1", ev.Payload.ParentSessionID) + require.Len(t, ev.Payload.Entries, 1) + require.Equal(t, "child-1", ev.Payload.Entries[0].ChildSessionID) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for register event") + } +} + +func TestRuntime_Finish_RemovesEntry(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + rt.Register("parent-1", "child-1", "my-agent", "blue", "") + rt.Finish("child-1", StatusCompleted) + + require.Empty(t, rt.List("parent-1")) +} + +func TestRuntime_Finish_PublishesFinishedEvent(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + rt.Register("parent-1", "child-1", "my-agent", "blue", "claude") + + ch := rt.Subscribe(ctx) + + rt.Finish("child-1", StatusCompleted) + + select { + case ev := <-ch: + require.NotNil(t, ev.Payload.Finished, "Finished must be set when sub-agent finishes") + require.Equal(t, "child-1", ev.Payload.Finished.ChildSessionID) + require.Equal(t, StatusCompleted, ev.Payload.Finished.Status) + require.Empty(t, ev.Payload.Entries) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for finish event") + } +} + +func TestRuntime_Finish_StatusFlowsThrough(t *testing.T) { + t.Parallel() + + cases := []string{StatusCompleted, StatusCancelled, StatusFailed} + for _, status := range cases { + t.Run(status, func(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + rt.Register("parent-1", "child-1", "agent", "red", "") + ch := rt.Subscribe(ctx) + + rt.Finish("child-1", status) + + select { + case ev := <-ch: + require.NotNil(t, ev.Payload.Finished) + require.Equal(t, status, ev.Payload.Finished.Status) + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for finish event with status %q", status) + } + }) + } +} + +func TestRuntime_Finish_UnknownChildIsNoOp(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + require.NotPanics(t, func() { + rt.Finish("missing", StatusCompleted) + }) +} + +func TestRuntime_NilSafe(t *testing.T) { + t.Parallel() + + var rt *Runtime + + require.NotPanics(t, func() { + rt.Register("parent-1", "child-1", "agent", "red", "") + }) + require.NotPanics(t, func() { + rt.Finish("child-1", StatusCompleted) + }) + require.NotPanics(t, func() { + rt.SetStatus("child-1", "queued") + }) + require.NotPanics(t, func() { + entries := rt.List("parent-1") + require.Nil(t, entries) + }) + require.NotPanics(t, func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ch := rt.Subscribe(ctx) + // Channel must be closed (nil Runtime acts like a shut-down broker). + select { + case _, ok := <-ch: + require.False(t, ok, "Subscribe on nil Runtime must return a closed channel") + case <-time.After(100 * time.Millisecond): + t.Fatal("Subscribe on nil Runtime did not return a closed channel") + } + }) + require.NotPanics(t, func() { + rt.Shutdown() + }) +} + +func TestRuntime_Shutdown(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + ch := rt.Subscribe(ctx) + rt.Shutdown() + + select { + case _, ok := <-ch: + require.False(t, ok, "channel must be closed after Shutdown") + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for channel to close after Shutdown") + } +} + +func TestRuntime_ConcurrentAccess(t *testing.T) { + t.Parallel() + + rt := NewRuntime() + t.Cleanup(rt.Shutdown) + + const goroutines = 20 + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := range goroutines { + go func(i int) { + defer wg.Done() + childID := "child-" + string(rune('A'+i)) + rt.Register("parent-shared", childID, "agent", "white", "") + rt.List("parent-shared") + rt.SetStatus(childID, "queued") + rt.List("parent-shared") + rt.Finish(childID, StatusCompleted) + }(i) + } + + wg.Wait() +} + +// Compile-time assertion: Subscribe must return the correct channel type. +var _ <-chan pubsub.Event[RuntimeEvent] = (*Runtime)(nil).Subscribe(context.Background()) diff --git a/internal/subagents/subagents.go b/internal/subagents/subagents.go new file mode 100644 index 0000000000..975e1c4fcc --- /dev/null +++ b/internal/subagents/subagents.go @@ -0,0 +1,509 @@ +// Package subagents implements parsing and validation of subagent definition +// files. +package subagents + +import ( + "errors" + "fmt" + "log/slog" + "os" + "regexp" + "slices" + "strings" + "sync" + + "github.com/charlievieth/fastwalk" + "gopkg.in/yaml.v3" + + "github.com/charmbracelet/crush/internal/config" +) + +const ( + // MaxNameLength is the max characters allowed in a subagent name. + MaxNameLength = 64 + // MaxDescriptionLength is the max characters allowed in a subagent + // description. + MaxDescriptionLength = 1024 +) + +// namePattern matches valid subagent names: lowercase alphanumeric with single +// hyphens, no leading or trailing hyphens, no consecutive hyphens. +var namePattern = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`) + +// reservedNames is the set of names that may not be used for subagents. +var reservedNames = map[string]bool{ + "agent": true, + "task": true, + "coder": true, + "bash": true, + "view": true, + "edit": true, + "grep": true, + "glob": true, + "write": true, + "ls": true, + "mcp": true, +} + +// ToolList is a []string that YAML-unmarshals from either a comma-separated +// scalar string ("Read, Grep, Bash") or a YAML sequence (["Read","Grep"]). +// When the field is absent the value stays nil. +type ToolList []string + +// UnmarshalYAML implements yaml.Unmarshaler for ToolList. +func (t *ToolList) UnmarshalYAML(value *yaml.Node) error { + switch value.Kind { + case yaml.ScalarNode: + if value.Value == "" || value.Tag == "!!null" { + return nil + } + parts := strings.Split(value.Value, ",") + result := make([]string, 0, len(parts)) + for _, p := range parts { + if trimmed := strings.TrimSpace(p); trimmed != "" { + result = append(result, trimmed) + } + } + if len(result) > 0 { + *t = result + } + return nil + case yaml.SequenceNode: + var items []string + if err := value.Decode(&items); err != nil { + return err + } + if len(items) > 0 { + *t = items + } + return nil + default: + return nil + } +} + +// Subagent is a parsed subagent definition file. +type Subagent struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Tools ToolList `yaml:"tools"` + DisallowedTools ToolList `yaml:"disallowedTools"` + Model string `yaml:"model"` + Effort string `yaml:"effort"` + Skills []string `yaml:"skills"` + MCPServers []string `yaml:"mcp_servers"` + PermissionMode string `yaml:"permissionMode"` + Color string `yaml:"color"` + Provider string `yaml:"provider"` + Body string // set from markdown body after frontmatter + FilePath string // set from the file path passed to Parse +} + +// ResolvedColor returns the subagent's explicit Color if set, or falls back to +// AutoColor(Name) for a deterministic palette assignment. +func (s Subagent) ResolvedColor() string { + if s.Color != "" { + return s.Color + } + return AutoColor(s.Name) +} + +// PermissionMode values accepted in the PermissionMode field. +const ( + PermissionModeDefault = "default" + PermissionModeBypassPermissions = "bypassPermissions" +) + +// ToConfigAgent converts the Subagent into a config.Agent by applying the +// subagent's tool restrictions and model preference on top of the provided +// base agent configuration. +func (s *Subagent) ToConfigAgent(base config.Agent) config.Agent { + // Start with a copy of the base allowed tools — never mutate the original. + pool := append([]string(nil), base.AllowedTools...) + + // Apply disallowed tools first. + if len(s.DisallowedTools) > 0 { + disallowed := make(map[string]bool, len(s.DisallowedTools)) + for _, t := range s.DisallowedTools { + disallowed[t] = true + } + filtered := pool[:0:0] + for _, t := range pool { + if !disallowed[t] { + filtered = append(filtered, t) + } + } + pool = filtered + } + + // Intersect with the explicit tools allowlist (cannot widen beyond base). + if len(s.Tools) > 0 { + allowed := make(map[string]bool, len(s.Tools)) + for _, t := range s.Tools { + allowed[t] = true + } + filtered := pool[:0:0] + for _, t := range pool { + if allowed[t] { + filtered = append(filtered, t) + } + } + pool = filtered + } + + // Build AllowedMCP only when MCP servers are specified. + var allowedMCP map[string][]string + if len(s.MCPServers) > 0 { + allowedMCP = make(map[string][]string, len(s.MCPServers)) + for _, srv := range s.MCPServers { + allowedMCP[srv] = nil + } + } + + return config.Agent{ + ID: s.Name, + Name: s.Name, + Description: s.Description, + AllowedTools: pool, + AllowedMCP: allowedMCP, + // Model selection is driven by the coordinator from the raw `model:` + // value (alias or specific id); inherit the base type here. This field + // is no longer consumed for subagents. + Model: base.Model, + } +} + +// ParseContent parses a subagent definition from raw bytes. +func ParseContent(content []byte) (*Subagent, error) { + frontmatter, body, err := splitFrontmatter(string(content)) + if err != nil { + return nil, err + } + + var agent Subagent + if err := yaml.Unmarshal([]byte(frontmatter), &agent); err != nil { + return nil, fmt.Errorf("parsing frontmatter: %w", err) + } + + agent.Body = strings.TrimSpace(body) + + return &agent, nil +} + +// Parse reads a subagent definition file from disk and sets FilePath. +func Parse(path string) (*Subagent, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + agent, err := ParseContent(content) + if err != nil { + return nil, err + } + + agent.FilePath = path + + return agent, nil +} + +// ValidateAgainst runs Validate plus model-resolution checks. When isKnownModel +// is non-nil and Model is a non-empty value other than "large"/"small", the +// resolver must return true or validation fails. A nil resolver skips the +// model check (used when the caller has no config context). +func (s *Subagent) ValidateAgainst(isKnownModel func(provider, model string) bool) error { + err := s.Validate() + if isKnownModel == nil { + return err + } + if s.Model == "" || s.Model == "large" || s.Model == "small" { + return err + } + if !isKnownModel(s.Provider, s.Model) { + modelErr := fmt.Errorf("model %q is not a known model id; use \"large\", \"small\", or a valid provider model id", s.Model) + if err == nil { + return modelErr + } + return errors.Join(err, modelErr) + } + return err +} + +// Validate checks that the subagent meets all specification requirements. +// Multiple errors are joined with errors.Join. +func (s *Subagent) Validate() error { + var errs []error + + if s.Name == "" { + errs = append(errs, errors.New("name is required")) + } else { + if len(s.Name) > MaxNameLength { + errs = append(errs, fmt.Errorf("name exceeds %d characters", MaxNameLength)) + } + if !namePattern.MatchString(s.Name) { + errs = append(errs, errors.New("name must be lowercase alphanumeric with single hyphens (no leading, trailing, or consecutive hyphens)")) + } + if reservedNames[s.Name] { + errs = append(errs, fmt.Errorf("name %q is reserved", s.Name)) + } + } + + if s.Description == "" { + errs = append(errs, errors.New("description is required")) + } else if len(s.Description) > MaxDescriptionLength { + errs = append(errs, fmt.Errorf("description exceeds %d characters", MaxDescriptionLength)) + } + + if len(s.Tools) > 0 && len(s.DisallowedTools) > 0 { + disallowedSet := make(map[string]bool, len(s.DisallowedTools)) + for _, tool := range s.DisallowedTools { + disallowedSet[tool] = true + } + for _, tool := range s.Tools { + if disallowedSet[tool] { + errs = append(errs, fmt.Errorf("tool %q appears in both tools and disallowedTools", tool)) + } + } + } + + switch s.Effort { + case "", EffortNone, EffortMinimal, EffortLow, EffortMedium, EffortHigh, EffortXHigh, EffortMax: + default: + errs = append(errs, fmt.Errorf("effort %q is not valid; use one of: %q, %q, %q, %q, %q, %q, %q", s.Effort, EffortNone, EffortMinimal, EffortLow, EffortMedium, EffortHigh, EffortXHigh, EffortMax)) + } + + switch s.PermissionMode { + case "", PermissionModeDefault, PermissionModeBypassPermissions: + default: + errs = append(errs, fmt.Errorf("permissionMode %q is not valid; use %q or %q", s.PermissionMode, PermissionModeDefault, PermissionModeBypassPermissions)) + } + + if s.Color != "" && !IsValidColor(s.Color) { + errs = append(errs, fmt.Errorf("color %q is not valid; use one of: red, orange, yellow, green, cyan, blue, purple, pink", s.Color)) + } + + if s.Provider != "" && (s.Model == "" || s.Model == "large" || s.Model == "small") { + errs = append(errs, fmt.Errorf("provider requires a specific model id; use a valid provider model id (not empty, %q, or %q)", "large", "small")) + } + + return errors.Join(errs...) +} + +// Filter removes subagents whose names appear in the disabled list. +func Filter(all []*Subagent, disabled []string) []*Subagent { + if len(disabled) == 0 { + return all + } + + disabledSet := make(map[string]bool, len(disabled)) + for _, name := range disabled { + disabledSet[name] = true + } + + result := make([]*Subagent, 0, len(all)) + for _, s := range all { + if !disabledSet[s.Name] { + result = append(result, s) + } + } + return result +} + +// Deduplicate removes duplicate subagents by name. When duplicates exist, the +// last occurrence wins. +func Deduplicate(all []*Subagent) []*Subagent { + if len(all) == 0 { + return nil + } + + seen := make(map[string]int, len(all)) + for i, s := range all { + seen[s.Name] = i + } + + result := make([]*Subagent, 0, len(seen)) + for i, s := range all { + if seen[s.Name] == i { + result = append(result, s) + } + } + return result +} + +// DiscoveryState represents the outcome of discovering a single subagent file. +type DiscoveryState int + +const ( + // StateNormal indicates the subagent was parsed and validated successfully. + StateNormal DiscoveryState = iota + // StateError indicates discovery encountered a scan/parse/validate error. + StateError +) + +// SubagentState represents the latest discovery status of a subagent file. +type SubagentState struct { + Name string + Path string + State DiscoveryState + Err error +} + +// Event is published when subagent discovery completes. +type Event struct { + States []*SubagentState +} + +// cloneSubagents returns a shallow copy of the slice so callers cannot mutate +// the manager's internal slice header. The underlying *Subagent pointers are +// shared — subagents are immutable post-discovery. +func cloneSubagents(in []*Subagent) []*Subagent { + if in == nil { + return nil + } + out := make([]*Subagent, len(in)) + copy(out, in) + return out +} + +// cloneStates returns a deep copy of the given state slice so callers cannot +// accidentally mutate the source. +func cloneStates(states []*SubagentState) []*SubagentState { + if states == nil { + return nil + } + result := make([]*SubagentState, len(states)) + for i, s := range states { + clone := *s + result[i] = &clone + } + return result +} + +// DeduplicateStates removes duplicate subagent states by name. When duplicates +// exist, the last occurrence wins (consistent with Deduplicate for subagents). +func DeduplicateStates(all []*SubagentState) []*SubagentState { + seen := make(map[string]int, len(all)) + for i, s := range all { + if s.Name != "" { + seen[s.Name] = i + } + } + + result := make([]*SubagentState, 0, len(seen)) + for i, s := range all { + // Keep the last occurrence of this name, or anything without a + // name (error state). + if s.Name == "" || seen[s.Name] == i { + result = append(result, s) + } + } + return result +} + +// DiscoverWithStates finds all valid subagent definition files (*.md) in the +// given paths recursively, and returns both the discovered subagents and a +// per-file state slice describing parse/validation outcomes. When +// isKnownModel is non-nil it is used to validate non-alias model ids; nil +// skips that check. +func DiscoverWithStates(paths []string, isKnownModel func(provider, model string) bool) ([]*Subagent, []*SubagentState) { + var agents []*Subagent + var states []*SubagentState + var mu sync.Mutex + seen := make(map[string]bool) + + addState := func(name, path string, state DiscoveryState, err error) { + mu.Lock() + states = append(states, &SubagentState{ + Name: name, + Path: path, + State: state, + Err: err, + }) + mu.Unlock() + } + + for _, base := range paths { + conf := fastwalk.Config{ + Follow: true, + ToSlash: fastwalk.DefaultToSlash(), + } + err := fastwalk.Walk(&conf, base, func(path string, d os.DirEntry, err error) error { + if err != nil { + slog.Warn("Failed to walk subagents path entry", "base", base, "path", path, "error", err) + addState("", path, StateError, err) + return nil + } + if d.IsDir() || !strings.HasSuffix(d.Name(), ".md") { + return nil + } + mu.Lock() + if seen[path] { + mu.Unlock() + return nil + } + seen[path] = true + mu.Unlock() + + agent, err := Parse(path) + if err != nil { + slog.Warn("Failed to parse subagent file", "path", path, "error", err) + addState("", path, StateError, err) + return nil + } + if err := agent.ValidateAgainst(isKnownModel); err != nil { + slog.Warn("Subagent validation failed", "path", path, "error", err) + addState(agent.Name, path, StateError, err) + return nil + } + slog.Debug("Successfully loaded subagent", "name", agent.Name, "path", path) + mu.Lock() + agents = append(agents, agent) + mu.Unlock() + addState(agent.Name, path, StateNormal, nil) + return nil + }) + if err != nil && !os.IsNotExist(err) { + slog.Warn("Failed to walk subagents path", "path", base, "error", err) + } + } + + // fastwalk traversal order is non-deterministic, so sort for stable output. + // Sort by filepath first, then alphabetically by name within each path. + slices.SortStableFunc(agents, func(a, b *Subagent) int { + if c := strings.Compare(strings.ToLower(a.FilePath), strings.ToLower(b.FilePath)); c != 0 { + return c + } + return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name)) + }) + + return agents, states +} + +// splitFrontmatter extracts YAML frontmatter and body from markdown content. +func splitFrontmatter(content string) (frontmatter, body string, err error) { + // Strip UTF-8 BOM for compatibility with editors that include it. + content = strings.TrimPrefix(content, "\ufeff") + // Normalize line endings to \n for consistent parsing. + content = strings.ReplaceAll(content, "\r\n", "\n") + content = strings.ReplaceAll(content, "\r", "\n") + + lines := strings.Split(content, "\n") + start := slices.IndexFunc(lines, func(line string) bool { + return strings.TrimSpace(line) != "" + }) + if start == -1 || strings.TrimSpace(lines[start]) != "---" { + return "", "", errors.New("no YAML frontmatter found") + } + + endOffset := slices.IndexFunc(lines[start+1:], func(line string) bool { + return strings.TrimSpace(line) == "---" + }) + if endOffset == -1 { + return "", "", errors.New("unclosed frontmatter") + } + end := start + 1 + endOffset + + frontmatter = strings.Join(lines[start+1:end], "\n") + body = strings.Join(lines[end+1:], "\n") + return frontmatter, body, nil +} diff --git a/internal/subagents/subagents_test.go b/internal/subagents/subagents_test.go new file mode 100644 index 0000000000..e7f4267b31 --- /dev/null +++ b/internal/subagents/subagents_test.go @@ -0,0 +1,1018 @@ +package subagents + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseContent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + wantTools []string + wantDisallowed []string + wantName string + wantDescription string + wantModel string + wantSkills []string + wantMCPServers []string + wantPermMode string + wantBody string + wantErr bool + }{ + { + name: "comma_separated_tools", + content: `--- +name: my-agent +description: A test agent. +tools: Read, Grep, Bash +--- +`, + wantName: "my-agent", + wantDescription: "A test agent.", + wantTools: []string{"Read", "Grep", "Bash"}, + }, + { + name: "yaml_array_tools", + content: `--- +name: my-agent +description: A test agent. +tools: + - Read + - Grep +--- +`, + wantName: "my-agent", + wantDescription: "A test agent.", + wantTools: []string{"Read", "Grep"}, + }, + { + name: "no_tools_field", + content: `--- +name: my-agent +description: A test agent. +--- +`, + wantName: "my-agent", + wantDescription: "A test agent.", + wantTools: nil, + }, + { + name: "disallowed_tools_comma", + content: `--- +name: my-agent +description: A test agent. +disallowedTools: Write, Edit +--- +`, + wantName: "my-agent", + wantDescription: "A test agent.", + wantDisallowed: []string{"Write", "Edit"}, + }, + { + name: "all_fields", + content: `--- +name: my-agent +description: A fully specified agent. +model: large +tools: + - Read + - Bash +disallowedTools: Write, Edit +skills: + - pdf-processing + - data-analysis +mcp_servers: + - filesystem +--- + +This is the system prompt body. +`, + wantName: "my-agent", + wantDescription: "A fully specified agent.", + wantModel: "large", + wantTools: []string{"Read", "Bash"}, + wantDisallowed: []string{"Write", "Edit"}, + wantSkills: []string{"pdf-processing", "data-analysis"}, + wantMCPServers: []string{"filesystem"}, + wantBody: "This is the system prompt body.", + }, + { + name: "permission_mode_bypass_decoded", + content: `--- +name: bypass-agent +description: An agent with bypass permissions. +permissionMode: bypassPermissions +--- + +Body. +`, + wantName: "bypass-agent", + wantDescription: "An agent with bypass permissions.", + wantPermMode: PermissionModeBypassPermissions, + wantBody: "Body.", + }, + { + name: "body_extracted", + content: `--- +name: my-agent +description: A test agent. +--- + +# System Prompt + +Do the thing. +`, + wantName: "my-agent", + wantDescription: "A test agent.", + wantBody: "# System Prompt\n\nDo the thing.", + }, + { + name: "utf8_bom_stripped", + content: "\uFEFF---\n" + + "name: bom-agent\n" + + "description: Agent with BOM.\n" + + "---\n\n" + + "Body here.\n", + wantName: "bom-agent", + wantDescription: "Agent with BOM.", + wantBody: "Body here.", + }, + { + name: "leading_blank_lines", + content: "\n\n---\n" + + "name: blank-prefix\n" + + "description: Agent with leading blank lines.\n" + + "---\n\n" + + "Body here.\n", + wantName: "blank-prefix", + wantDescription: "Agent with leading blank lines.", + wantBody: "Body here.", + }, + { + name: "no_frontmatter", + content: "# Just Markdown\n\nNo frontmatter here.", + wantErr: true, + }, + { + name: "unclosed_frontmatter", + content: `--- +name: my-agent +description: Never closed. +`, + wantErr: true, + }, + { + name: "empty_content", + content: "", + wantErr: true, + }, + { + name: "only_bom", + content: "\ufeff", + wantErr: true, + }, + { + name: "only_whitespace", + content: " \n\n \t\n", + wantErr: true, + }, + { + name: "empty_frontmatter_no_body", + content: "---\n---\n", + wantName: "", + wantDescription: "", + }, + { + name: "crlf_line_endings", + content: "---\r\nname: crlf-agent\r\n" + + "description: Uses CRLF endings.\r\n---\r\n\r\nBody.\r\n", + wantName: "crlf-agent", + wantDescription: "Uses CRLF endings.", + wantBody: "Body.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + agent, err := ParseContent([]byte(tt.content)) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, agent) + + require.Equal(t, tt.wantName, agent.Name) + require.Equal(t, tt.wantDescription, agent.Description) + require.Equal(t, tt.wantTools, []string(agent.Tools)) + require.Equal(t, tt.wantDisallowed, []string(agent.DisallowedTools)) + + if tt.wantModel != "" { + require.Equal(t, tt.wantModel, agent.Model) + } + if tt.wantSkills != nil { + require.Equal(t, tt.wantSkills, agent.Skills) + } + if tt.wantMCPServers != nil { + require.Equal(t, tt.wantMCPServers, agent.MCPServers) + } + require.Equal(t, tt.wantPermMode, agent.PermissionMode) + if tt.wantBody != "" { + require.Equal(t, tt.wantBody, agent.Body) + } + }) + } +} + +func TestParse(t *testing.T) { + t.Parallel() + + t.Run("reads file and sets filepath", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "my-agent.md") + require.NoError(t, os.WriteFile(path, []byte(`--- +name: my-agent +description: A test agent. +--- + +Body here. +`), 0o644)) + + agent, err := Parse(path) + require.NoError(t, err) + require.Equal(t, "my-agent", agent.Name) + require.Equal(t, "A test agent.", agent.Description) + require.Equal(t, "Body here.", agent.Body) + require.Equal(t, path, agent.FilePath) + }) + + t.Run("missing file returns error", func(t *testing.T) { + t.Parallel() + + _, err := Parse(filepath.Join(t.TempDir(), "nonexistent.md")) + require.Error(t, err) + }) +} + +func TestValidate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + agent Subagent + wantErr bool + errMsg string + }{ + { + name: "valid_minimal", + agent: Subagent{Name: "my-agent", Description: "Does something."}, + }, + { + name: "missing_name", + agent: Subagent{Description: "Something."}, + wantErr: true, + errMsg: "name is required", + }, + { + name: "missing_description", + agent: Subagent{Name: "my-agent"}, + wantErr: true, + errMsg: "description is required", + }, + { + name: "uppercase_in_name", + agent: Subagent{Name: "MyAgent", Description: "Something."}, + wantErr: true, + errMsg: "lowercase", + }, + { + name: "name_too_long", + agent: Subagent{Name: strings.Repeat("a", 65), Description: "Something."}, + wantErr: true, + errMsg: "exceeds", + }, + { + name: "reserved_name_agent", + agent: Subagent{Name: "agent", Description: "Something."}, + wantErr: true, + errMsg: "reserved", + }, + { + name: "reserved_name_task", + agent: Subagent{Name: "task", Description: "Something."}, + wantErr: true, + errMsg: "reserved", + }, + { + name: "reserved_name_bash", + agent: Subagent{Name: "bash", Description: "Something."}, + wantErr: true, + errMsg: "reserved", + }, + { + name: "tools_disallowed_overlap", + agent: Subagent{ + Name: "my-agent", + Description: "Something.", + Tools: ToolList{"bash", "grep"}, + DisallowedTools: ToolList{"bash"}, + }, + wantErr: true, + errMsg: "both", + }, + { + name: "description_too_long", + agent: Subagent{Name: "my-agent", Description: strings.Repeat("a", 1025)}, + wantErr: true, + errMsg: "description", + }, + { + name: "starts_with_hyphen", + agent: Subagent{Name: "-my-agent", Description: "Something."}, + wantErr: true, + errMsg: "lowercase", + }, + { + name: "consecutive_hyphens", + agent: Subagent{Name: "my--agent", Description: "Something."}, + wantErr: true, + errMsg: "lowercase", + }, + { + name: "permission_mode_default_valid", + agent: Subagent{Name: "my-agent", Description: "Something.", PermissionMode: PermissionModeDefault}, + wantErr: false, + }, + { + name: "permission_mode_bypass_valid", + agent: Subagent{Name: "my-agent", Description: "Something.", PermissionMode: PermissionModeBypassPermissions}, + wantErr: false, + }, + { + name: "permission_mode_accept_edits_rejected", + agent: Subagent{Name: "my-agent", Description: "Something.", PermissionMode: "acceptEdits"}, + wantErr: true, + errMsg: "permissionMode", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.agent.Validate() + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateAgainst(t *testing.T) { + t.Parallel() + + knownModels := map[string]bool{"gpt-4o": true, "claude-opus-4-7": true} + isKnown := func(provider, id string) bool { return knownModels[id] } + + tests := []struct { + name string + agent Subagent + wantErr bool + errMsg string + }{ + { + name: "model_empty_ok", + agent: Subagent{Name: "a", Description: "d", Model: ""}, + }, + { + name: "model_large_ok", + agent: Subagent{Name: "a", Description: "d", Model: "large"}, + }, + { + name: "model_small_ok", + agent: Subagent{Name: "a", Description: "d", Model: "small"}, + }, + { + name: "known_model_id_ok", + agent: Subagent{Name: "a", Description: "d", Model: "gpt-4o"}, + }, + { + name: "unknown_model_rejected", + agent: Subagent{Name: "a", Description: "d", Model: "imaginary-99"}, + wantErr: true, + errMsg: "model", + }, + { + name: "still_runs_base_validation", + agent: Subagent{Name: "", Description: "d", Model: "large"}, + wantErr: true, + errMsg: "name is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.agent.ValidateAgainst(isKnown) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateAgainst_NilResolver_AcceptsAnyNonEmptyModel(t *testing.T) { + t.Parallel() + + // Without a resolver, model id strings cannot be validated; ValidateAgainst + // should accept any non-empty model string and defer enforcement. + s := Subagent{Name: "a", Description: "d", Model: "gpt-99-future"} + require.NoError(t, s.ValidateAgainst(nil)) +} + +// TestValidateAgainst_ProviderPropagated verifies that ValidateAgainst forwards +// the subagent's Provider field as the first argument to the resolver. +func TestValidateAgainst_ProviderPropagated(t *testing.T) { + t.Parallel() + + var capturedProvider, capturedModel string + isKnown := func(provider, model string) bool { + capturedProvider = provider + capturedModel = model + return true + } + + s := Subagent{Name: "a", Description: "d", Provider: "openai", Model: "gpt-4o"} + require.NoError(t, s.ValidateAgainst(isKnown)) + require.Equal(t, "openai", capturedProvider) + require.Equal(t, "gpt-4o", capturedModel) +} + +// TestValidateAgainst_EmptyProviderPropagated verifies that when Provider is +// empty, ValidateAgainst calls the resolver with an empty provider string +// (allowing callers to perform an all-provider scan). +func TestValidateAgainst_EmptyProviderPropagated(t *testing.T) { + t.Parallel() + + var capturedProvider string + isKnown := func(provider, model string) bool { + capturedProvider = provider + return true + } + + s := Subagent{Name: "a", Description: "d", Provider: "", Model: "gpt-4o"} + require.NoError(t, s.ValidateAgainst(isKnown)) + require.Equal(t, "", capturedProvider) +} + +// TestParseContent_ProviderField verifies that the provider field round-trips +// through YAML frontmatter parsing. +func TestParseContent_ProviderField(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + wantProvider string + }{ + { + name: "provider_set", + content: `--- +name: my-agent +description: A test agent. +provider: openai +model: gpt-4o +--- +`, + wantProvider: "openai", + }, + { + name: "provider_absent_is_empty", + content: `--- +name: my-agent +description: A test agent. +--- +`, + wantProvider: "", + }, + { + name: "provider_explicit_empty", + content: `--- +name: my-agent +description: A test agent. +provider: "" +--- +`, + wantProvider: "", + }, + { + name: "provider_anthropic", + content: `--- +name: my-agent +description: A test agent. +provider: anthropic +model: claude-opus-4-7 +--- +`, + wantProvider: "anthropic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + agent, err := ParseContent([]byte(tt.content)) + require.NoError(t, err) + require.Equal(t, tt.wantProvider, agent.Provider) + }) + } +} + +// TestValidate_ProviderRequiresSpecificModel verifies that when provider is set, +// model must be a specific model ID (not empty, "large", or "small"). +func TestValidate_ProviderRequiresSpecificModel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + agent Subagent + wantErr bool + errMsg string + }{ + { + name: "provider_set_no_model", + agent: Subagent{Name: "my-agent", Description: "Does something.", Provider: "openai"}, + wantErr: true, + errMsg: "model", + }, + { + name: "provider_set_model_large", + agent: Subagent{Name: "my-agent", Description: "Does something.", Provider: "openai", Model: "large"}, + wantErr: true, + errMsg: "model", + }, + { + name: "provider_set_model_small", + agent: Subagent{Name: "my-agent", Description: "Does something.", Provider: "openai", Model: "small"}, + wantErr: true, + errMsg: "model", + }, + { + name: "provider_set_specific_model_ok", + agent: Subagent{Name: "my-agent", Description: "Does something.", Provider: "openai", Model: "gpt-4o"}, + }, + { + name: "no_provider_model_large_ok", + agent: Subagent{Name: "my-agent", Description: "Does something.", Provider: "", Model: "large"}, + }, + { + name: "no_provider_no_model_ok", + agent: Subagent{Name: "my-agent", Description: "Does something.", Provider: "", Model: ""}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.agent.Validate() + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestFilter(t *testing.T) { + t.Parallel() + + all := []*Subagent{ + {Name: "a"}, + {Name: "b"}, + {Name: "c"}, + } + + tests := []struct { + name string + disabled []string + wantLen int + }{ + {"nil_disabled", nil, 3}, + {"filter_one", []string{"b"}, 2}, + {"filter_all", []string{"a", "b", "c"}, 0}, + {"filter_nonexistent", []string{"z"}, 3}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := Filter(all, tt.disabled) + require.Len(t, result, tt.wantLen) + }) + } +} + +func TestDeduplicate(t *testing.T) { + t.Parallel() + + t.Run("no_duplicates", func(t *testing.T) { + t.Parallel() + + input := []*Subagent{{Name: "a", FilePath: "/a"}, {Name: "b", FilePath: "/b"}} + result := Deduplicate(input) + require.Len(t, result, 2) + }) + + t.Run("last_wins", func(t *testing.T) { + t.Parallel() + + input := []*Subagent{ + {Name: "a", FilePath: "/a"}, + {Name: "a", FilePath: "/b"}, + } + result := Deduplicate(input) + require.Len(t, result, 1) + require.Equal(t, "/b", result[0].FilePath) + }) + + t.Run("empty_input", func(t *testing.T) { + t.Parallel() + + result := Deduplicate(nil) + require.Empty(t, result) + }) +} + +// TestParseContent_ColorField verifies that the color field round-trips through +// YAML frontmatter parsing for all defined color values plus absent/empty. +func TestParseContent_ColorField(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + wantColor string + }{ + { + name: "color_red", + content: `--- +name: my-agent +description: A test agent. +color: red +--- +`, + wantColor: "red", + }, + { + name: "color_orange", + content: `--- +name: my-agent +description: A test agent. +color: orange +--- +`, + wantColor: "orange", + }, + { + name: "color_yellow", + content: `--- +name: my-agent +description: A test agent. +color: yellow +--- +`, + wantColor: "yellow", + }, + { + name: "color_green", + content: `--- +name: my-agent +description: A test agent. +color: green +--- +`, + wantColor: "green", + }, + { + name: "color_cyan", + content: `--- +name: my-agent +description: A test agent. +color: cyan +--- +`, + wantColor: "cyan", + }, + { + name: "color_blue", + content: `--- +name: my-agent +description: A test agent. +color: blue +--- +`, + wantColor: "blue", + }, + { + name: "color_purple", + content: `--- +name: my-agent +description: A test agent. +color: purple +--- +`, + wantColor: "purple", + }, + { + name: "color_pink", + content: `--- +name: my-agent +description: A test agent. +color: pink +--- +`, + wantColor: "pink", + }, + { + name: "color_absent_is_empty", + content: `--- +name: my-agent +description: A test agent. +--- +`, + wantColor: "", + }, + { + name: "color_explicit_empty_string", + content: `--- +name: my-agent +description: A test agent. +color: "" +--- +`, + wantColor: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + agent, err := ParseContent([]byte(tt.content)) + require.NoError(t, err) + require.Equal(t, tt.wantColor, agent.Color) + }) + } +} + +// TestValidate_ColorField verifies that Validate accepts all eight defined +// color constants and empty, and rejects everything else with an error +// mentioning "color". +func TestValidate_ColorField(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + color string + wantErr bool + }{ + {name: "empty_accepted", color: ""}, + {name: "red_accepted", color: ColorRed}, + {name: "orange_accepted", color: ColorOrange}, + {name: "yellow_accepted", color: ColorYellow}, + {name: "green_accepted", color: ColorGreen}, + {name: "cyan_accepted", color: ColorCyan}, + {name: "blue_accepted", color: ColorBlue}, + {name: "purple_accepted", color: ColorPurple}, + {name: "pink_accepted", color: ColorPink}, + {name: "ultra_rejected", color: "ultra", wantErr: true}, + {name: "RED_rejected_case_sensitive", color: "RED", wantErr: true}, + {name: "lime_rejected", color: "lime", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + s := Subagent{ + Name: "test-agent", + Description: "Does something.", + Color: tt.color, + } + err := s.Validate() + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "color") + } else { + require.NoError(t, err) + } + }) + } +} + +// TestResolvedColor_ExplicitColor verifies that ResolvedColor returns the +// explicitly set Color value when it is non-empty. +func TestResolvedColor_ExplicitColor(t *testing.T) { + t.Parallel() + + s := Subagent{ + Name: "my-agent", + Description: "Does something.", + Color: ColorBlue, + } + require.Equal(t, ColorBlue, s.ResolvedColor()) +} + +// TestResolvedColor_AutoFallback verifies that when Color is empty, +// ResolvedColor returns AutoColor(Name): a non-empty string that is one of the +// eight valid color names. +func TestResolvedColor_AutoFallback(t *testing.T) { + t.Parallel() + + s := Subagent{ + Name: "my-agent", + Description: "Does something.", + Color: "", + } + + result := s.ResolvedColor() + require.NotEmpty(t, result, "ResolvedColor must not return empty when Color is unset") + require.True(t, IsValidColor(result), "ResolvedColor fallback %q must be a valid color", result) + require.Equal(t, AutoColor(s.Name), result, "ResolvedColor must equal AutoColor(Name) when Color is empty") +} + +func TestDiscoverWithStates(t *testing.T) { + t.Parallel() + + t.Run("discovers_valid_agents_recursively", func(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + subdir := filepath.Join(tmp, "subdir") + require.NoError(t, os.MkdirAll(subdir, 0o755)) + + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "top-agent.md"), + []byte("---\nname: top-agent\ndescription: Top level agent.\n---\n\nYou are a specialist agent.\n"), + 0o644, + )) + require.NoError(t, os.WriteFile( + filepath.Join(subdir, "sub-agent.md"), + []byte("---\nname: sub-agent\ndescription: Nested agent.\n---\n\nYou are a nested specialist agent.\n"), + 0o644, + )) + + agents, states := DiscoverWithStates([]string{tmp}, nil) + + require.Len(t, agents, 2) + names := make([]string, 0, len(agents)) + for _, a := range agents { + names = append(names, a.Name) + } + require.Contains(t, names, "top-agent") + require.Contains(t, names, "sub-agent") + require.Len(t, states, 2) + }) + + t.Run("invalid_agent_no_frontmatter_appears_as_error_not_in_agents", func(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "bad-agent.md"), + []byte("# No frontmatter here\n\nJust markdown.\n"), + 0o644, + )) + + agents, states := DiscoverWithStates([]string{tmp}, nil) + + require.Empty(t, agents) + require.Len(t, states, 1) + require.Equal(t, StateError, states[0].State) + require.Error(t, states[0].Err) + }) + + t.Run("nonexistent_path_silently_skipped", func(t *testing.T) { + t.Parallel() + + agents, states := DiscoverWithStates([]string{filepath.Join(t.TempDir(), "does-not-exist")}, nil) + + require.Empty(t, agents) + require.Empty(t, states) + }) + + t.Run("empty_dir_returns_no_results", func(t *testing.T) { + t.Parallel() + + agents, states := DiscoverWithStates([]string{t.TempDir()}, nil) + + require.Empty(t, agents) + require.Empty(t, states) + }) + + t.Run("non_md_files_ignored", func(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "agent.txt"), + []byte("---\nname: txt-agent\ndescription: Should be ignored.\n---\n\nBody.\n"), + 0o644, + )) + + agents, states := DiscoverWithStates([]string{tmp}, nil) + + require.Empty(t, agents) + require.Empty(t, states) + }) + + t.Run("resolver_receives_provider_and_model", func(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "specific-agent.md"), + []byte("---\nname: specific-agent\ndescription: Uses a specific model.\nprovider: openai\nmodel: gpt-4o\n---\n\nBody.\n"), + 0o644, + )) + + var capturedProvider, capturedModel string + isKnown := func(provider, model string) bool { + capturedProvider = provider + capturedModel = model + return true + } + + agents, states := DiscoverWithStates([]string{tmp}, isKnown) + + require.Len(t, agents, 1) + require.Len(t, states, 1) + require.Equal(t, StateNormal, states[0].State) + require.Equal(t, "openai", capturedProvider) + require.Equal(t, "gpt-4o", capturedModel) + }) + + t.Run("unknown_model_with_resolver_produces_error_state", func(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + require.NoError(t, os.WriteFile( + filepath.Join(tmp, "unknown-model-agent.md"), + []byte("---\nname: unknown-model-agent\ndescription: Uses an unknown model.\nmodel: no-such-model-99\n---\n\nBody.\n"), + 0o644, + )) + + isKnown := func(provider, model string) bool { return false } + + agents, states := DiscoverWithStates([]string{tmp}, isKnown) + + require.Empty(t, agents) + require.Len(t, states, 1) + require.Equal(t, StateError, states[0].State) + require.Error(t, states[0].Err) + }) +} + +// TestValidate_ReportsAllToolOverlaps verifies that when multiple tools appear +// in both Tools and DisallowedTools, Validate reports every overlapping tool +// rather than stopping at the first. The fix removed a break so all overlaps +// are joined via errors.Join. +func TestValidate_ReportsAllToolOverlaps(t *testing.T) { + t.Parallel() + + sa := Subagent{ + Name: "reviewer", + Description: "Reviews things.", + Tools: ToolList{"view", "edit", "bash"}, + DisallowedTools: ToolList{"view", "edit", "bash"}, + } + + err := sa.Validate() + require.Error(t, err) + require.ErrorContains(t, err, "view") + require.ErrorContains(t, err, "edit") + require.ErrorContains(t, err, "bash") +} diff --git a/internal/subagents/to_config_agent_test.go b/internal/subagents/to_config_agent_test.go new file mode 100644 index 0000000000..d01089f85c --- /dev/null +++ b/internal/subagents/to_config_agent_test.go @@ -0,0 +1,191 @@ +package subagents + +import ( + "testing" + + "github.com/charmbracelet/crush/internal/config" + "github.com/stretchr/testify/require" +) + +func TestToConfigAgent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + subagent Subagent + base config.Agent + check func(t *testing.T, result config.Agent) + }{ + { + name: "no_restrictions", + subagent: Subagent{Name: "my-agent", Description: "Does something."}, + base: config.Agent{ + AllowedTools: []string{"bash", "grep", "view"}, + Model: config.SelectedModelTypeLarge, + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + require.Equal(t, []string{"bash", "grep", "view"}, result.AllowedTools) + }, + }, + { + name: "tools_filter", + subagent: Subagent{ + Name: "my-agent", + Description: "Does something.", + Tools: ToolList{"grep", "view"}, + }, + base: config.Agent{ + AllowedTools: []string{"bash", "grep", "view", "edit"}, + Model: config.SelectedModelTypeLarge, + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + require.ElementsMatch(t, []string{"grep", "view"}, result.AllowedTools) + }, + }, + { + name: "disallowed_tools", + subagent: Subagent{ + Name: "my-agent", + Description: "Does something.", + DisallowedTools: ToolList{"view"}, + }, + base: config.Agent{ + AllowedTools: []string{"bash", "grep", "view"}, + Model: config.SelectedModelTypeLarge, + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + require.ElementsMatch(t, []string{"bash", "grep"}, result.AllowedTools) + }, + }, + { + name: "both_filters_disallowed_first", + subagent: Subagent{ + Name: "my-agent", + Description: "Does something.", + DisallowedTools: ToolList{"bash"}, + Tools: ToolList{"grep", "bash"}, + }, + base: config.Agent{ + AllowedTools: []string{"bash", "grep", "view"}, + Model: config.SelectedModelTypeLarge, + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + // disallowed removes "bash" first → base becomes ["grep","view"] + // then tools filter intersects with ["grep","bash"] → only "grep" survives + require.ElementsMatch(t, []string{"grep"}, result.AllowedTools) + }, + }, + { + name: "mcp_servers", + subagent: Subagent{ + Name: "my-agent", + Description: "Does something.", + MCPServers: []string{"github", "linear"}, + }, + base: config.Agent{ + Model: config.SelectedModelTypeLarge, + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + require.NotNil(t, result.AllowedMCP) + require.Contains(t, result.AllowedMCP, "github") + require.Contains(t, result.AllowedMCP, "linear") + // values should be nil (all tools from that MCP allowed) + require.Nil(t, result.AllowedMCP["github"]) + require.Nil(t, result.AllowedMCP["linear"]) + }, + }, + { + name: "mcp_servers_empty", + subagent: Subagent{ + Name: "my-agent", + Description: "Does something.", + MCPServers: nil, + }, + base: config.Agent{ + Model: config.SelectedModelTypeLarge, + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + require.Nil(t, result.AllowedMCP) + }, + }, + { + // The `model:` field no longer drives config.Agent.Model — model + // selection moved to the coordinator (buildAgent). ToConfigAgent + // inherits the base type regardless of the subagent's model alias. + name: "model_small_alias_inherits_base", + subagent: Subagent{ + Name: "my-agent", + Description: "Does something.", + Model: "small", + }, + base: config.Agent{ + Model: config.SelectedModelTypeLarge, + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + require.Equal(t, config.SelectedModelTypeLarge, result.Model) + }, + }, + { + name: "model_large_alias_inherits_base", + subagent: Subagent{ + Name: "my-agent", + Description: "Does something.", + Model: "large", + }, + base: config.Agent{ + Model: config.SelectedModelTypeSmall, + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + require.Equal(t, config.SelectedModelTypeSmall, result.Model) + }, + }, + { + name: "model_empty", + subagent: Subagent{ + Name: "my-agent", + Description: "Does something.", + Model: "", + }, + base: config.Agent{ + Model: config.SelectedModelTypeLarge, + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + require.Equal(t, config.SelectedModelTypeLarge, result.Model) + }, + }, + { + name: "id_and_name", + subagent: Subagent{ + Name: "my-agent", + Description: "Does something.", + }, + base: config.Agent{ + ID: "old-id", + Name: "old-name", + }, + check: func(t *testing.T, result config.Agent) { + t.Helper() + require.Equal(t, "my-agent", result.ID) + require.Equal(t, "my-agent", result.Name) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := tt.subagent.ToConfigAgent(tt.base) + tt.check(t, result) + }) + } +} diff --git a/internal/ui/chat/agent.go b/internal/ui/chat/agent.go index 8342ebaf83..74d0577963 100644 --- a/internal/ui/chat/agent.go +++ b/internal/ui/chat/agent.go @@ -123,20 +123,33 @@ type AgentToolRenderContext struct { agent *AgentToolMessageItem } +// agentRenderLabel returns the header label for an agent tool row: the plain +// "Agent" string for the default task agent, or "Agent: " when a +// specialized subagent has been dispatched. +func agentRenderLabel(subagentType string) string { + if subagentType == "" || subagentType == "task" { + return "Agent" + } + return "Agent: " + subagentType +} + // RenderTool implements the [ToolRenderer] interface. func (r *AgentToolRenderContext) RenderTool(sty *styles.Styles, width int, opts *ToolRenderOpts) string { cappedWidth := cappedMessageWidth(width) - if !opts.ToolCall.Finished && !opts.IsCanceled() && len(r.agent.nestedTools) == 0 { - return pendingTool(sty, "Agent", opts.Anim, opts.Compact) - } var params agent.AgentParams _ = json.Unmarshal([]byte(opts.ToolCall.Input), ¶ms) + label := agentRenderLabel(params.SubagentType) + + if !opts.ToolCall.Finished && !opts.IsCanceled() && len(r.agent.nestedTools) == 0 { + return pendingTool(sty, label, opts.Anim, opts.Compact) + } + prompt := params.Prompt prompt = strings.ReplaceAll(prompt, "\n", " ") - header := toolHeader(sty, opts.Status, "Agent", cappedWidth, opts.Compact) + header := toolHeader(sty, opts.Status, label, cappedWidth, opts.Compact) if opts.Compact { return header } diff --git a/internal/ui/chat/agent_label_test.go b/internal/ui/chat/agent_label_test.go new file mode 100644 index 0000000000..84a6d66acf --- /dev/null +++ b/internal/ui/chat/agent_label_test.go @@ -0,0 +1,29 @@ +package chat + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAgentRenderLabel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + subagentType string + want string + }{ + {"empty_returns_agent", "", "Agent"}, + {"task_returns_agent", "task", "Agent"}, + {"named_subagent_prefixed", "code-reviewer", "Agent: code-reviewer"}, + {"another_named", "tester", "Agent: tester"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, agentRenderLabel(tt.subagentType)) + }) + } +} diff --git a/internal/ui/chat/tools.go b/internal/ui/chat/tools.go index 961173f30b..77ea4c6d10 100644 --- a/internal/ui/chat/tools.go +++ b/internal/ui/chat/tools.go @@ -1253,6 +1253,9 @@ func (t *baseToolMessageItem) formatParametersForCopy() string { case agent.AgentToolName: var params agent.AgentParams if json.Unmarshal([]byte(t.toolCall.Input), ¶ms) == nil { + if params.SubagentType != "" && params.SubagentType != "task" { + return fmt.Sprintf("**Subagent:** %s\n\n**Task:**\n%s", params.SubagentType, params.Prompt) + } return fmt.Sprintf("**Task:**\n%s", params.Prompt) } } diff --git a/internal/ui/completions/completions.go b/internal/ui/completions/completions.go index 9393cce524..e095edbbdd 100644 --- a/internal/ui/completions/completions.go +++ b/internal/ui/completions/completions.go @@ -42,6 +42,7 @@ type ClosedMsg struct{} type CompletionItemsLoadedMsg struct { Files []FileCompletionValue Resources []ResourceCompletionValue + Subagents []SubagentCompletionValue } // Completions represents the completions popup component. @@ -141,9 +142,12 @@ func (c *Completions) KeyMap() KeyMap { } // Open opens the completions with file items from the filesystem. -func (c *Completions) Open(depth, limit int) tea.Cmd { +// subagentItems are already in memory so they are passed directly rather than +// loaded asynchronously. +func (c *Completions) Open(depth, limit int, subagentItems []SubagentCompletionValue) tea.Cmd { return func() tea.Msg { var msg CompletionItemsLoadedMsg + msg.Subagents = subagentItems var wg sync.WaitGroup wg.Go(func() { msg.Files = loadFiles(depth, limit) @@ -156,11 +160,24 @@ func (c *Completions) Open(depth, limit int) tea.Cmd { } } -// SetItems sets the files and MCP resources and rebuilds the merged list. -func (c *Completions) SetItems(files []FileCompletionValue, resources []ResourceCompletionValue) { - items := make([]list.FilterableItem, 0, len(files)+len(resources)) +// SetItems sets the subagents, files and MCP resources and rebuilds the +// merged list. Subagents appear first so they sit at the top of the popup. +func (c *Completions) SetItems(files []FileCompletionValue, resources []ResourceCompletionValue, subagents []SubagentCompletionValue) { + items := make([]list.FilterableItem, 0, len(subagents)+len(files)+len(resources)) - // Add files first. + // Subagents appear first. + for _, sa := range subagents { + item := NewCompletionItem( + sa.Name, + sa, + c.normalStyle, + c.focusedStyle, + c.matchStyle, + ) + items = append(items, item) + } + + // Files. for _, file := range files { item := NewCompletionItem( file.Path, @@ -172,7 +189,7 @@ func (c *Completions) SetItems(files []FileCompletionValue, resources []Resource items = append(items, item) } - // Add MCP resources. + // MCP resources. for _, resource := range resources { item := NewCompletionItem( resource.MCPName+"/"+cmp.Or(resource.Title, resource.URI), @@ -385,6 +402,11 @@ func (c *Completions) selectCurrent(keepOpen bool) tea.Msg { Value: item, KeepOpen: keepOpen, } + case SubagentCompletionValue: + return SelectionMsg[SubagentCompletionValue]{ + Value: item, + KeepOpen: keepOpen, + } default: return nil } diff --git a/internal/ui/completions/completions_test.go b/internal/ui/completions/completions_test.go index 906fc3f4c3..66b6d8d45e 100644 --- a/internal/ui/completions/completions_test.go +++ b/internal/ui/completions/completions_test.go @@ -14,7 +14,7 @@ func TestFilterPrefersExactBasenameStem(t *testing.T) { c.SetItems([]FileCompletionValue{ {Path: "internal/ui/chat/search.go"}, {Path: "internal/ui/chat/user.go"}, - }, nil) + }, nil, nil) c.Filter("user") @@ -33,7 +33,7 @@ func TestFilterPrefersBasenamePrefix(t *testing.T) { c.SetItems([]FileCompletionValue{ {Path: "internal/ui/chat/mcp.go"}, {Path: "internal/ui/model/chat.go"}, - }, nil) + }, nil, nil) c.Filter("chat.g") @@ -96,7 +96,7 @@ func TestFilterPrefersPathSegmentExact(t *testing.T) { c.SetItems([]FileCompletionValue{ {Path: "internal/ui/model/xychat.go"}, {Path: "internal/ui/chat/mcp.go"}, - }, nil) + }, nil, nil) c.Filter("chat") @@ -106,3 +106,104 @@ func TestFilterPrefersPathSegmentExact(t *testing.T) { require.True(t, ok) require.Equal(t, "internal/ui/chat/mcp.go", first.Text()) } + +// TestSetItems_SubagentsAppearsInList verifies that a subagent passed via the +// third argument to SetItems is represented in the filtered list as an item +// whose Text() equals the subagent name. +func TestSetItems_SubagentsAppearsInList(t *testing.T) { + t.Parallel() + + c := New(lipgloss.NewStyle(), lipgloss.NewStyle(), lipgloss.NewStyle()) + c.SetItems(nil, nil, []SubagentCompletionValue{ + {Name: "code-reviewer", Description: "reviews code"}, + }) + + var found bool + for _, item := range c.filtered { + ci, ok := item.(*CompletionItem) + if !ok { + continue + } + if ci.Text() == "code-reviewer" { + found = true + break + } + } + require.True(t, found, "expected to find a completion item with text %q", "code-reviewer") +} + +// TestSetItems_SubagentAndFilesCoexist verifies that when SetItems is called +// with both file and subagent entries, items for both appear in the filtered +// list. +func TestSetItems_SubagentAndFilesCoexist(t *testing.T) { + t.Parallel() + + c := New(lipgloss.NewStyle(), lipgloss.NewStyle(), lipgloss.NewStyle()) + c.SetItems( + []FileCompletionValue{{Path: "cmd/main.go"}}, + nil, + []SubagentCompletionValue{{Name: "tester", Description: "writes tests"}}, + ) + + texts := make([]string, 0, len(c.filtered)) + for _, item := range c.filtered { + ci, ok := item.(*CompletionItem) + if !ok { + continue + } + texts = append(texts, ci.Text()) + } + + require.Contains(t, texts, "cmd/main.go", "file item must appear in filtered list") + require.Contains(t, texts, "tester", "subagent item must appear in filtered list") +} + +// TestSetItems_NilSubagents_NoError verifies that calling SetItems with a nil +// subagents slice does not panic and still populates file items normally. +func TestSetItems_NilSubagents_NoError(t *testing.T) { + t.Parallel() + + c := New(lipgloss.NewStyle(), lipgloss.NewStyle(), lipgloss.NewStyle()) + require.NotPanics(t, func() { + c.SetItems([]FileCompletionValue{{Path: "internal/foo.go"}}, nil, nil) + }) + + require.NotEmpty(t, c.filtered, "file items must still be populated when subagents is nil") + + var found bool + for _, item := range c.filtered { + ci, ok := item.(*CompletionItem) + if !ok { + continue + } + if ci.Text() == "internal/foo.go" { + found = true + break + } + } + require.True(t, found, "expected to find file item %q when subagents is nil", "internal/foo.go") +} + +// TestSetItems_PreservesSubagentOrder verifies that multiple subagents appear in +// the filtered list in the order they were passed to SetItems. This pins the +// ordering contract so frontend display matches input order. +func TestSetItems_PreservesSubagentOrder(t *testing.T) { + t.Parallel() + + c := New(lipgloss.NewStyle(), lipgloss.NewStyle(), lipgloss.NewStyle()) + c.SetItems(nil, nil, []SubagentCompletionValue{ + {Name: "zeta"}, + {Name: "alpha"}, + {Name: "mu"}, + }) + + require.Len(t, c.filtered, 3) + + got := make([]string, 0, 3) + for _, item := range c.filtered { + ci, ok := item.(*CompletionItem) + require.True(t, ok) + got = append(got, ci.Text()) + } + require.Equal(t, []string{"zeta", "alpha", "mu"}, got) +} diff --git a/internal/ui/completions/item.go b/internal/ui/completions/item.go index b149e58a50..052f8918e5 100644 --- a/internal/ui/completions/item.go +++ b/internal/ui/completions/item.go @@ -23,6 +23,12 @@ type ResourceCompletionValue struct { MIMEType string } +// SubagentCompletionValue represents a subagent @-mention completion value. +type SubagentCompletionValue struct { + Name string + Description string +} + // CompletionItem represents an item in the completions list. type CompletionItem struct { *list.Versioned diff --git a/internal/ui/dialog/actions.go b/internal/ui/dialog/actions.go index 0e96b06ad8..05023dd3b9 100644 --- a/internal/ui/dialog/actions.go +++ b/internal/ui/dialog/actions.go @@ -128,6 +128,11 @@ type ( } ) +// ActionLoadSubagentSession is a message to load a subagent's child session. +type ActionLoadSubagentSession struct { + SessionID string +} + // ActionCmd represents an action that carries a [tea.Cmd] to be passed to the // Bubble Tea program loop. type ActionCmd struct { diff --git a/internal/ui/dialog/commands.go b/internal/ui/dialog/commands.go index 6e17db70d0..e951bd4d6b 100644 --- a/internal/ui/dialog/commands.go +++ b/internal/ui/dialog/commands.go @@ -431,6 +431,7 @@ func (c *Commands) defaultCommands() []*CommandItem { commands := []*CommandItem{ NewCommandItem(c.com.Styles, "new_session", "New Session", "ctrl+n", ActionNewSession{}), NewCommandItem(c.com.Styles, "switch_session", "Sessions", "ctrl+s", ActionOpenDialog{SessionsID}), + NewCommandItem(c.com.Styles, "subagents", "Subagents", "ctrl+x", ActionOpenDialog{SubagentsID}), NewCommandItem(c.com.Styles, "switch_model", "Switch Model", "ctrl+l", ActionOpenDialog{ModelsID}), } diff --git a/internal/ui/dialog/subagents.go b/internal/ui/dialog/subagents.go new file mode 100644 index 0000000000..77d590c8c6 --- /dev/null +++ b/internal/ui/dialog/subagents.go @@ -0,0 +1,458 @@ +package dialog + +import ( + "charm.land/bubbles/v2/help" + "charm.land/bubbles/v2/key" + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/ui/list" + "github.com/charmbracelet/crush/internal/ui/util" + uv "github.com/charmbracelet/ultraviolet" +) + +// SubagentsID is the identifier for the subagents dialog. +const SubagentsID = "subagents" + +// SubagentsTab identifies which tab of the subagents dialog is active. +type SubagentsTab int + +// Possible tabs in the subagents dialog. +const ( + SubagentsTabRunning SubagentsTab = iota + SubagentsTabLibrary +) + +// Subagents is a dialog that shows running and library subagents. +type Subagents struct { + com *common.Common + tab SubagentsTab + parentSessionID string + runningList *list.FilterableList + libraryList *list.FilterableList + runningItems []*RunningSubagentItem + libraryItems []*LibrarySubagentItem + confirmDelete bool + + keyMap struct { + Tab key.Binding + Next key.Binding + Previous key.Binding + Enter key.Binding + Cancel key.Binding + Delete key.Binding + Toggle key.Binding + ConfirmDelete key.Binding + CancelDelete key.Binding + Close key.Binding + } + help help.Model +} + +var _ Dialog = (*Subagents)(nil) + +// NewSubagents creates a new [Subagents] dialog. It populates the running tab +// from com.Workspace.RunningSubagents(parentSessionID) and the library tab +// from com.Workspace.AllSubagents(). +func NewSubagents(com *common.Common, parentSessionID string) *Subagents { + s := &Subagents{ + com: com, + tab: SubagentsTabRunning, + parentSessionID: parentSessionID, + } + + h := help.New() + h.Styles = com.Styles.DialogHelpStyles() + s.help = h + + // Build running items. + running := com.Workspace.RunningSubagents(parentSessionID) + runningFilterable := make([]list.FilterableItem, len(running)) + s.runningItems = make([]*RunningSubagentItem, len(running)) + for i, r := range running { + item := NewRunningSubagentItem(com.Styles, RunningSubagentItemData{ + ChildSessionID: r.ChildSessionID, + Name: r.Name, + Color: r.Color, + Model: r.Model, + PromptTokens: r.PromptTokens, + CompletionTokens: r.CompletionTokens, + }) + s.runningItems[i] = item + runningFilterable[i] = item + } + s.runningList = list.NewFilterableList(runningFilterable...) + s.runningList.Focus() + s.runningList.SetSelected(0) + + // Build library items. + defs := com.Workspace.AllSubagents() + libraryFilterable := make([]list.FilterableItem, len(defs)) + s.libraryItems = make([]*LibrarySubagentItem, len(defs)) + for i, d := range defs { + item := NewLibrarySubagentItem(com.Styles, LibrarySubagentItemData{ + Name: d.Name, + Description: d.Description, + Color: d.Color, + FilePath: d.FilePath, + Scope: d.Scope, + Disabled: d.Disabled, + }) + s.libraryItems[i] = item + libraryFilterable[i] = item + } + s.libraryList = list.NewFilterableList(libraryFilterable...) + s.libraryList.SetSelected(0) + + s.keyMap.Tab = key.NewBinding( + key.WithKeys("tab", "shift+tab"), + key.WithHelp("tab", "switch tab"), + ) + s.keyMap.Next = key.NewBinding( + key.WithKeys("down", "ctrl+n"), + key.WithHelp("↓", "next item"), + ) + s.keyMap.Previous = key.NewBinding( + key.WithKeys("up", "ctrl+p"), + key.WithHelp("↑", "previous item"), + ) + s.keyMap.Enter = key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "select"), + ) + s.keyMap.Cancel = key.NewBinding( + key.WithKeys("x"), + key.WithHelp("x", "cancel subagent"), + ) + s.keyMap.Delete = key.NewBinding( + key.WithKeys("d"), + key.WithHelp("d", "delete"), + ) + s.keyMap.Toggle = key.NewBinding( + key.WithKeys("space"), + key.WithHelp("space", "enable/disable"), + ) + s.keyMap.ConfirmDelete = key.NewBinding( + key.WithKeys("y"), + key.WithHelp("y", "confirm delete"), + ) + s.keyMap.CancelDelete = key.NewBinding( + key.WithKeys("n", "esc"), + key.WithHelp("n", "cancel delete"), + ) + s.keyMap.Close = key.NewBinding( + key.WithKeys("esc", "alt+esc"), + key.WithHelp("esc", "close"), + ) + + return s +} + +// ID implements [Dialog]. +func (s *Subagents) ID() string { + return SubagentsID +} + +// ActiveTab returns the currently active tab. +func (s *Subagents) ActiveTab() SubagentsTab { + return s.tab +} + +// IsConfirmingDelete reports whether the dialog is in confirm-delete mode. +func (s *Subagents) IsConfirmingDelete() bool { + return s.confirmDelete +} + +// activeList returns the list for the currently active tab. +func (s *Subagents) activeList() *list.FilterableList { + if s.tab == SubagentsTabLibrary { + return s.libraryList + } + return s.runningList +} + +// HandleMsg implements [Dialog]. +func (s *Subagents) HandleMsg(msg tea.Msg) Action { + keyMsg, ok := msg.(tea.KeyPressMsg) + if !ok { + return nil + } + + // In confirm-delete mode, only accept y/n/esc. + if s.confirmDelete { + switch { + case key.Matches(keyMsg, s.keyMap.ConfirmDelete): + return s.confirmDeleteSelected() + case key.Matches(keyMsg, s.keyMap.CancelDelete): + s.confirmDelete = false + } + return nil + } + + switch { + case key.Matches(keyMsg, s.keyMap.Close): + return ActionClose{} + + case key.Matches(keyMsg, s.keyMap.Tab): + s.toggleTab() + + case key.Matches(keyMsg, s.keyMap.Previous): + l := s.activeList() + if l.IsSelectedFirst() { + l.SelectLast() + } else { + l.SelectPrev() + } + l.ScrollToSelected() + + case key.Matches(keyMsg, s.keyMap.Next): + l := s.activeList() + if l.IsSelectedLast() { + l.SelectFirst() + } else { + l.SelectNext() + } + l.ScrollToSelected() + + case s.tab == SubagentsTabRunning && key.Matches(keyMsg, s.keyMap.Enter): + return s.loadSelectedRunning() + + case s.tab == SubagentsTabRunning && key.Matches(keyMsg, s.keyMap.Cancel): + s.cancelSelectedRunning() + + case s.tab == SubagentsTabLibrary && key.Matches(keyMsg, s.keyMap.Toggle): + return s.toggleSelectedLibrary() + + case s.tab == SubagentsTabLibrary && key.Matches(keyMsg, s.keyMap.Delete): + s.enterConfirmDelete() + } + + return nil +} + +// toggleSelectedLibrary flips the enabled/disabled state of the selected +// library item, optimistically dimming/undimming it, and issues a cmd that +// persists the change via the workspace. +func (s *Subagents) toggleSelectedLibrary() Action { + item := s.libraryList.SelectedItem() + if item == nil { + return nil + } + li, ok := item.(*LibrarySubagentItem) + if !ok { + return nil + } + li.data.Disabled = !li.data.Disabled + li.Bump() + return ActionCmd{s.setDisabledCmd(li.ID(), li.data.Disabled)} +} + +// setDisabledCmd returns a cmd that persists the disabled state for name and +// reports any error back to the program. +func (s *Subagents) setDisabledCmd(name string, disabled bool) tea.Cmd { + return func() tea.Msg { + if err := s.com.Workspace.SetSubagentDisabled(name, disabled); err != nil { + return util.ReportError(err)() + } + return nil + } +} + +// toggleTab switches between the Running and Library tabs. +func (s *Subagents) toggleTab() { + if s.tab == SubagentsTabRunning { + s.tab = SubagentsTabLibrary + s.libraryList.Focus() + } else { + s.tab = SubagentsTabRunning + s.runningList.Focus() + } +} + +// loadSelectedRunning returns an [ActionLoadSubagentSession] for the currently +// selected running subagent, or nil if nothing is selected. +func (s *Subagents) loadSelectedRunning() Action { + item := s.runningList.SelectedItem() + if item == nil { + return nil + } + ri, ok := item.(*RunningSubagentItem) + if !ok { + return nil + } + return ActionLoadSubagentSession{SessionID: ri.ID()} +} + +// cancelSelectedRunning cancels the currently selected running subagent via +// the workspace and removes it from the list. +func (s *Subagents) cancelSelectedRunning() { + item := s.runningList.SelectedItem() + if item == nil { + return + } + ri, ok := item.(*RunningSubagentItem) + if !ok { + return + } + childID := ri.ID() + s.com.Workspace.CancelSubagent(childID) + s.removeRunningItem(childID) +} + +// removeRunningItem removes the running item with the given child session ID +// from the list. +func (s *Subagents) removeRunningItem(childID string) { + var newItems []*RunningSubagentItem + for _, item := range s.runningItems { + if item.ID() == childID { + continue + } + newItems = append(newItems, item) + } + s.runningItems = newItems + filterable := make([]list.FilterableItem, len(s.runningItems)) + for i, item := range s.runningItems { + filterable[i] = item + } + s.runningList.SetItems(filterable...) + s.runningList.SelectFirst() +} + +// enterConfirmDelete sets confirm-delete mode for the currently selected +// library item, if it has user scope. +func (s *Subagents) enterConfirmDelete() { + item := s.libraryList.SelectedItem() + if item == nil { + return + } + li, ok := item.(*LibrarySubagentItem) + if !ok { + return + } + if li.data.Scope != "user" { + return + } + s.confirmDelete = true +} + +// confirmDeleteSelected issues a delete cmd for the selected library item and +// removes it from the list optimistically. +func (s *Subagents) confirmDeleteSelected() Action { + s.confirmDelete = false + item := s.libraryList.SelectedItem() + if item == nil { + return nil + } + li, ok := item.(*LibrarySubagentItem) + if !ok { + return nil + } + name := li.ID() + s.removeLibraryItem(name) + return ActionCmd{s.deleteSubagentCmd(name)} +} + +// deleteSubagentCmd returns a cmd that calls DeleteUserSubagent and reports any +// error back to the program. +func (s *Subagents) deleteSubagentCmd(name string) tea.Cmd { + return func() tea.Msg { + if err := s.com.Workspace.DeleteUserSubagent(name); err != nil { + return util.ReportError(err)() + } + return nil + } +} + +// removeLibraryItem removes the library item with the given name from the list. +func (s *Subagents) removeLibraryItem(name string) { + var newItems []*LibrarySubagentItem + for _, item := range s.libraryItems { + if item.ID() == name { + continue + } + newItems = append(newItems, item) + } + s.libraryItems = newItems + filterable := make([]list.FilterableItem, len(s.libraryItems)) + for i, item := range s.libraryItems { + filterable[i] = item + } + s.libraryList.SetItems(filterable...) + s.libraryList.SelectFirst() +} + +// Draw implements [Dialog]. +func (s *Subagents) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor { + t := s.com.Styles + width := max(0, min(defaultDialogMaxWidth, area.Dx()-t.Dialog.View.GetHorizontalBorderSize())) + height := max(0, min(defaultDialogHeight, area.Dy()-t.Dialog.View.GetVerticalBorderSize())) + innerWidth := width - t.Dialog.View.GetHorizontalFrameSize() + + heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight + + t.Dialog.HelpView.GetVerticalFrameSize() + + t.Dialog.View.GetVerticalFrameSize() + listHeight := height - heightOffset + listWidth := max(0, innerWidth-3) + + l := s.activeList() + l.SetSize(listWidth, listHeight) + s.help.SetWidth(innerWidth) + + rc := NewRenderContext(t, width) + rc.Title = "Subagents" + + // Build tab indicator for title info. + runningLabel := "Running" + libraryLabel := "Library" + var tabInfo string + if s.tab == SubagentsTabRunning { + tabInfo = t.Dialog.SelectedItem.Render(runningLabel) + " | " + libraryLabel + } else { + tabInfo = runningLabel + " | " + t.Dialog.SelectedItem.Render(libraryLabel) + } + rc.TitleInfo = " " + tabInfo + + listView := t.Dialog.List.Height(l.Height()).Render(l.Render()) + rc.AddPart(listView) + rc.Help = s.help.View(s) + + view := rc.Render() + DrawCenter(scr, area, view) + return nil +} + +// ShortHelp implements [help.KeyMap]. +func (s *Subagents) ShortHelp() []key.Binding { + if s.confirmDelete { + return []key.Binding{ + s.keyMap.ConfirmDelete, + s.keyMap.CancelDelete, + } + } + if s.tab == SubagentsTabRunning { + return []key.Binding{ + s.keyMap.Next, + s.keyMap.Enter, + s.keyMap.Cancel, + s.keyMap.Tab, + s.keyMap.Close, + } + } + return []key.Binding{ + s.keyMap.Next, + s.keyMap.Toggle, + s.keyMap.Delete, + s.keyMap.Tab, + s.keyMap.Close, + } +} + +// FullHelp implements [help.KeyMap]. +func (s *Subagents) FullHelp() [][]key.Binding { + bindings := s.ShortHelp() + var out [][]key.Binding + for i := 0; i < len(bindings); i += 4 { + end := min(i+4, len(bindings)) + out = append(out, bindings[i:end]) + } + return out +} diff --git a/internal/ui/dialog/subagents_library_item.go b/internal/ui/dialog/subagents_library_item.go new file mode 100644 index 0000000000..02e653ac26 --- /dev/null +++ b/internal/ui/dialog/subagents_library_item.go @@ -0,0 +1,114 @@ +package dialog + +import ( + "github.com/charmbracelet/crush/internal/ui/list" + "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/x/ansi" + "github.com/sahilm/fuzzy" +) + +// LibrarySubagentItemData holds the data for a library subagent list item. +type LibrarySubagentItemData struct { + Name string + Description string + Color string + FilePath string + Scope string + Disabled bool +} + +// LibrarySubagentItem wraps [LibrarySubagentItemData] to implement the +// [ListItem] interface for display in the subagents dialog library tab. +type LibrarySubagentItem struct { + *list.Versioned + t *styles.Styles + data LibrarySubagentItemData + m fuzzy.Match + focused bool +} + +var _ ListItem = &LibrarySubagentItem{Versioned: list.NewVersioned()} + +// NewLibrarySubagentItem creates a new [LibrarySubagentItem]. +func NewLibrarySubagentItem(t *styles.Styles, data LibrarySubagentItemData) *LibrarySubagentItem { + return &LibrarySubagentItem{ + Versioned: list.NewVersioned(), + t: t, + data: data, + } +} + +// Finished implements list.Item. Library subagent items are considered stable +// outside of explicit state mutations. +func (l *LibrarySubagentItem) Finished() bool { + return true +} + +// Filter implements [list.FilterableItem]. +func (l *LibrarySubagentItem) Filter() string { + return l.data.Name +} + +// ID implements [ListItem]. +func (l *LibrarySubagentItem) ID() string { + return l.data.Name +} + +// SetFocused implements [list.Focusable]. +func (l *LibrarySubagentItem) SetFocused(focused bool) { + if l.focused == focused { + return + } + l.focused = focused + if l.Versioned != nil { + l.Bump() + } +} + +// SetMatch implements [list.MatchSettable]. +func (l *LibrarySubagentItem) SetMatch(m fuzzy.Match) { + if sameFuzzyMatch(l.m, m) { + return + } + l.m = m + if l.Versioned != nil { + l.Bump() + } +} + +// Render implements list.Item. It renders the library subagent as two lines: +// the first line shows the colored dot, name, and scope badge; the second line +// shows the description. Disabled items are rendered with a dimmed style. +func (l *LibrarySubagentItem) Render(width int) string { + dot := styles.SubagentDot(l.data.Color) + + itemStyle := l.t.Dialog.NormalItem + if l.focused { + itemStyle = l.t.Dialog.SelectedItem + } + + innerWidth := max(0, width-itemStyle.GetHorizontalFrameSize()) + + scope := l.data.Scope + if scope == "" { + scope = "user" + } + + firstLine := dot + " " + l.data.Name + " " + scope + firstLine = ansi.Truncate(firstLine, innerWidth, "…") + + var content string + if l.data.Description != "" { + desc := ansi.Truncate(l.data.Description, innerWidth, "…") + content = firstLine + "\n" + desc + } else { + content = firstLine + } + + if l.data.Disabled { + dimStyle := l.t.Resource.AdditionalText + return dimStyle.Render(content) + } + + return itemStyle.Render(content) +} diff --git a/internal/ui/dialog/subagents_library_item_test.go b/internal/ui/dialog/subagents_library_item_test.go new file mode 100644 index 0000000000..5e8da58d54 --- /dev/null +++ b/internal/ui/dialog/subagents_library_item_test.go @@ -0,0 +1,66 @@ +package dialog + +import ( + "testing" + + uistyles "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/stretchr/testify/require" +) + +// TestLibrarySubagentItem_RenderContainsName verifies that the rendered output +// of a LibrarySubagentItem contains the agent name. +func TestLibrarySubagentItem_RenderContainsName(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + item := NewLibrarySubagentItem(&st, LibrarySubagentItemData{ + Name: "my-agent", + Description: "does stuff", + Scope: "user", + }) + + rendered := item.Render(60) + plain := stripANSIDialog(rendered) + + require.Contains(t, plain, "my-agent") +} + +// TestLibrarySubagentItem_RenderContainsScopeBadge verifies that the rendered +// output contains the scope badge text for the item's scope. +func TestLibrarySubagentItem_RenderContainsScopeBadge(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + item := NewLibrarySubagentItem(&st, LibrarySubagentItemData{ + Name: "my-agent", + Description: "does stuff", + Scope: "user", + }) + + rendered := item.Render(60) + plain := stripANSIDialog(rendered) + + require.Contains(t, plain, "user") +} + +// TestLibrarySubagentItem_DisabledItemRendered verifies that rendering a +// disabled item does not panic and still contains the agent name. +func TestLibrarySubagentItem_DisabledItemRendered(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + item := NewLibrarySubagentItem(&st, LibrarySubagentItemData{ + Name: "my-agent", + Description: "does stuff", + Scope: "project", + Disabled: true, + }) + + var rendered string + require.NotPanics(t, func() { + rendered = item.Render(60) + }) + + plain := stripANSIDialog(rendered) + require.Contains(t, plain, "my-agent") +} diff --git a/internal/ui/dialog/subagents_running_item.go b/internal/ui/dialog/subagents_running_item.go new file mode 100644 index 0000000000..500162524e --- /dev/null +++ b/internal/ui/dialog/subagents_running_item.go @@ -0,0 +1,96 @@ +package dialog + +import ( + "fmt" + + "github.com/charmbracelet/crush/internal/ui/list" + "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/x/ansi" + "github.com/sahilm/fuzzy" +) + +// RunningSubagentItemData holds the data for a running subagent list item. +type RunningSubagentItemData struct { + ChildSessionID string + Name string + Color string + Model string + PromptTokens int64 + CompletionTokens int64 +} + +// RunningSubagentItem wraps [RunningSubagentItemData] to implement the +// [ListItem] interface for display in the subagents dialog running tab. +type RunningSubagentItem struct { + *list.Versioned + t *styles.Styles + data RunningSubagentItemData + m fuzzy.Match + focused bool +} + +var _ ListItem = &RunningSubagentItem{Versioned: list.NewVersioned()} + +// NewRunningSubagentItem creates a new [RunningSubagentItem]. +func NewRunningSubagentItem(t *styles.Styles, data RunningSubagentItemData) *RunningSubagentItem { + return &RunningSubagentItem{ + Versioned: list.NewVersioned(), + t: t, + data: data, + } +} + +// Finished implements list.Item. Running subagent items are considered stable +// outside of explicit state mutations. +func (r *RunningSubagentItem) Finished() bool { + return true +} + +// Filter implements [list.FilterableItem]. +func (r *RunningSubagentItem) Filter() string { + return r.data.Name +} + +// ID implements [ListItem]. +func (r *RunningSubagentItem) ID() string { + return r.data.ChildSessionID +} + +// SetFocused implements [list.Focusable]. +func (r *RunningSubagentItem) SetFocused(focused bool) { + if r.focused == focused { + return + } + r.focused = focused + if r.Versioned != nil { + r.Bump() + } +} + +// SetMatch implements [list.MatchSettable]. +func (r *RunningSubagentItem) SetMatch(m fuzzy.Match) { + if sameFuzzyMatch(r.m, m) { + return + } + r.m = m + if r.Versioned != nil { + r.Bump() + } +} + +// Render implements list.Item. It renders the running subagent as a single +// line showing the colored dot, name, model, and total token count. +func (r *RunningSubagentItem) Render(width int) string { + dot := styles.SubagentDot(r.data.Color) + totalTokens := r.data.PromptTokens + r.data.CompletionTokens + tokStr := fmt.Sprintf("%d tok", totalTokens) + + itemStyle := r.t.Dialog.NormalItem + if r.focused { + itemStyle = r.t.Dialog.SelectedItem + } + + content := dot + " " + r.data.Name + " " + r.data.Model + " " + tokStr + content = ansi.Truncate(content, max(0, width-itemStyle.GetHorizontalFrameSize()), "…") + return itemStyle.Render(content) +} diff --git a/internal/ui/dialog/subagents_running_item_test.go b/internal/ui/dialog/subagents_running_item_test.go new file mode 100644 index 0000000000..018a6082b7 --- /dev/null +++ b/internal/ui/dialog/subagents_running_item_test.go @@ -0,0 +1,69 @@ +package dialog + +import ( + "testing" + + uistyles "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/stretchr/testify/require" +) + +// TestRunningSubagentItem_RenderContainsName verifies that the rendered output +// of a RunningSubagentItem contains the agent name and model string. +func TestRunningSubagentItem_RenderContainsName(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + item := NewRunningSubagentItem(&st, RunningSubagentItemData{ + Name: "my-agent", + Color: "blue", + Model: "claude-opus-4-7", + PromptTokens: 100, + CompletionTokens: 50, + }) + + rendered := item.Render(60) + plain := stripANSIDialog(rendered) + + require.Contains(t, plain, "my-agent") + require.Contains(t, plain, "claude-opus-4-7") +} + +// TestRunningSubagentItem_RenderContainsTokenCount verifies that the rendered +// output contains the sum of prompt and completion tokens formatted as "N tok". +func TestRunningSubagentItem_RenderContainsTokenCount(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + item := NewRunningSubagentItem(&st, RunningSubagentItemData{ + Name: "my-agent", + Color: "blue", + Model: "claude-opus-4-7", + PromptTokens: 100, + CompletionTokens: 50, + }) + + rendered := item.Render(60) + plain := stripANSIDialog(rendered) + + require.Contains(t, plain, "150 tok") +} + +// TestRunningSubagentItem_RenderContainsDot verifies that the rendered output +// contains the colored dot produced by styles.SubagentDot for the item's color. +func TestRunningSubagentItem_RenderContainsDot(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + item := NewRunningSubagentItem(&st, RunningSubagentItemData{ + Name: "my-agent", + Color: "blue", + Model: "claude-opus-4-7", + PromptTokens: 100, + CompletionTokens: 50, + }) + + rendered := item.Render(60) + dot := uistyles.SubagentDot("blue") + + require.Contains(t, rendered, dot) +} diff --git a/internal/ui/dialog/subagents_test.go b/internal/ui/dialog/subagents_test.go new file mode 100644 index 0000000000..ccc37f18b7 --- /dev/null +++ b/internal/ui/dialog/subagents_test.go @@ -0,0 +1,243 @@ +package dialog + +import ( + "strings" + "testing" + + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/crush/internal/workspace" + "github.com/stretchr/testify/require" +) + +// subagentsWorkspace stubs only the workspace methods exercised by the +// Subagents dialog. +type subagentsWorkspace struct { + workspace.Workspace + running []workspace.RunningSubagentInfo + defs []workspace.SubagentDefInfo + cancelledIDs []string + deletedNames []string + deleteUserErr error + disabledCalls []disabledCall +} + +type disabledCall struct { + name string + disabled bool +} + +func (w *subagentsWorkspace) SetSubagentDisabled(name string, disabled bool) error { + w.disabledCalls = append(w.disabledCalls, disabledCall{name: name, disabled: disabled}) + return nil +} + +func (w *subagentsWorkspace) RunningSubagents(_ string) []workspace.RunningSubagentInfo { + return w.running +} + +func (w *subagentsWorkspace) AllSubagents() []workspace.SubagentDefInfo { + return w.defs +} + +func (w *subagentsWorkspace) CancelSubagent(childSessionID string) { + w.cancelledIDs = append(w.cancelledIDs, childSessionID) +} + +func (w *subagentsWorkspace) DeleteUserSubagent(name string) error { + w.deletedNames = append(w.deletedNames, name) + return w.deleteUserErr +} + +func newTestSubagentsDialog(t *testing.T, ws *subagentsWorkspace) *Subagents { + t.Helper() + st := styles.CharmtonePantera() + com := &common.Common{Styles: &st, Workspace: ws} + return NewSubagents(com, "parent-session-id") +} + +// TestSubagentsDialog_ImplementsDialogInterface is a compile-time assertion. +var _ Dialog = (*Subagents)(nil) + +// TestSubagentsDialog_TabToggle verifies that tab key toggles between +// Running and Library tabs and that a second tab returns to Running. +func TestSubagentsDialog_TabToggle(t *testing.T) { + t.Parallel() + + ws := &subagentsWorkspace{ + running: []workspace.RunningSubagentInfo{ + {ChildSessionID: "child-1", Name: "agent-one", Color: "blue", Model: "claude-opus-4-7"}, + }, + defs: []workspace.SubagentDefInfo{ + {Name: "lib-agent", Scope: "user"}, + }, + } + d := newTestSubagentsDialog(t, ws) + + require.Equal(t, SubagentsTabRunning, d.ActiveTab(), "initial tab should be Running") + + d.HandleMsg(tea.KeyPressMsg{Code: tea.KeyTab}) + require.Equal(t, SubagentsTabLibrary, d.ActiveTab(), "after one tab, should be Library") + + d.HandleMsg(tea.KeyPressMsg{Code: tea.KeyTab}) + require.Equal(t, SubagentsTabRunning, d.ActiveTab(), "after two tabs, should return to Running") +} + +// TestSubagentsDialog_EnterOnRunningItem verifies that pressing enter on +// a running subagent row returns ActionLoadSubagentSession with the correct +// child session ID. +func TestSubagentsDialog_EnterOnRunningItem(t *testing.T) { + t.Parallel() + + ws := &subagentsWorkspace{ + running: []workspace.RunningSubagentInfo{ + {ChildSessionID: "child-session-42", Name: "my-agent", Color: "red", Model: "claude-opus-4-7"}, + }, + } + d := newTestSubagentsDialog(t, ws) + + action := d.HandleMsg(tea.KeyPressMsg{Code: tea.KeyEnter}) + + loaded, ok := action.(ActionLoadSubagentSession) + require.True(t, ok, "enter on running item should return ActionLoadSubagentSession, got %T", action) + require.Equal(t, "child-session-42", loaded.SessionID) +} + +// TestSubagentsDialog_XCancelsRunningSubagent verifies that pressing x on a +// running subagent row calls CancelSubagent with the child session ID. +func TestSubagentsDialog_XCancelsRunningSubagent(t *testing.T) { + t.Parallel() + + ws := &subagentsWorkspace{ + running: []workspace.RunningSubagentInfo{ + {ChildSessionID: "child-cancel-me", Name: "cancellable-agent", Color: "green", Model: "claude-sonnet"}, + }, + } + d := newTestSubagentsDialog(t, ws) + + d.HandleMsg(keyMsg('x')) + + require.Contains(t, ws.cancelledIDs, "child-cancel-me", "CancelSubagent must be called with child session ID") +} + +// TestSubagentsDialog_EscReturnsActionClose verifies that pressing esc +// returns ActionClose{}. +func TestSubagentsDialog_EscReturnsActionClose(t *testing.T) { + t.Parallel() + + ws := &subagentsWorkspace{} + d := newTestSubagentsDialog(t, ws) + + action := d.HandleMsg(tea.KeyPressMsg{Code: tea.KeyEscape}) + + _, ok := action.(ActionClose) + require.True(t, ok, "esc should return ActionClose{}, got %T", action) +} + +// TestSubagentsDialog_DeleteLibraryItem verifies that pressing d on a +// user-scoped library item enters confirm-delete mode, and pressing y +// calls DeleteUserSubagent with the item name. +func TestSubagentsDialog_DeleteLibraryItem(t *testing.T) { + t.Parallel() + + ws := &subagentsWorkspace{ + defs: []workspace.SubagentDefInfo{ + {Name: "user-agent", Description: "does stuff", Scope: "user"}, + }, + } + d := newTestSubagentsDialog(t, ws) + + // Navigate to Library tab first. + d.HandleMsg(tea.KeyPressMsg{Code: tea.KeyTab}) + require.Equal(t, SubagentsTabLibrary, d.ActiveTab()) + + // Press d to enter confirm-delete mode. + d.HandleMsg(keyMsg('d')) + require.True(t, d.IsConfirmingDelete(), "pressing d should enter confirm-delete mode") + + // Press y to confirm deletion; execute the returned cmd to drive the IO. + action := d.HandleMsg(keyMsg('y')) + if ac, ok := action.(ActionCmd); ok && ac.Cmd != nil { + ac.Cmd() + } + require.Contains(t, ws.deletedNames, "user-agent", "DeleteUserSubagent must be called with agent name") +} + +// TestSubagentsDialog_DeleteLibraryItem_Cancel verifies that pressing d +// then n cancels the deletion without calling DeleteUserSubagent. +func TestSubagentsDialog_DeleteLibraryItem_Cancel(t *testing.T) { + t.Parallel() + + ws := &subagentsWorkspace{ + defs: []workspace.SubagentDefInfo{ + {Name: "user-agent", Description: "does stuff", Scope: "user"}, + }, + } + d := newTestSubagentsDialog(t, ws) + + // Navigate to Library tab. + d.HandleMsg(tea.KeyPressMsg{Code: tea.KeyTab}) + + // Enter confirm-delete mode. + d.HandleMsg(keyMsg('d')) + require.True(t, d.IsConfirmingDelete()) + + // Cancel with n. + d.HandleMsg(keyMsg('n')) + require.False(t, d.IsConfirmingDelete(), "pressing n should exit confirm-delete mode") + require.Empty(t, ws.deletedNames, "DeleteUserSubagent must not be called when deletion is cancelled") +} + +// TestSubagentsDialog_ToggleLibraryItem verifies that pressing space on a +// library item toggles its disabled state, calling SetSubagentDisabled with +// alternating values (disable then re-enable). +func TestSubagentsDialog_ToggleLibraryItem(t *testing.T) { + t.Parallel() + + ws := &subagentsWorkspace{ + defs: []workspace.SubagentDefInfo{ + {Name: "lib-agent", Description: "does stuff", Scope: "user", Disabled: false}, + }, + } + d := newTestSubagentsDialog(t, ws) + + d.HandleMsg(tea.KeyPressMsg{Code: tea.KeyTab}) + require.Equal(t, SubagentsTabLibrary, d.ActiveTab()) + + runCmd := func(action Action) { + if ac, ok := action.(ActionCmd); ok && ac.Cmd != nil { + ac.Cmd() + } + } + + runCmd(d.HandleMsg(keyMsg(' '))) + require.Len(t, ws.disabledCalls, 1) + require.Equal(t, "lib-agent", ws.disabledCalls[0].name) + require.True(t, ws.disabledCalls[0].disabled, "first toggle must disable") + + runCmd(d.HandleMsg(keyMsg(' '))) + require.Len(t, ws.disabledCalls, 2) + require.False(t, ws.disabledCalls[1].disabled, "second toggle must re-enable") +} + +// stripANSIDialog strips ANSI escape sequences from a string for plain-text +// assertions in dialog tests. +func stripANSIDialog(s string) string { + var b strings.Builder + esc := false + for i := 0; i < len(s); i++ { + if s[i] == '\x1b' { + esc = true + continue + } + if esc { + if (s[i] >= 'a' && s[i] <= 'z') || (s[i] >= 'A' && s[i] <= 'Z') { + esc = false + } + continue + } + b.WriteByte(s[i]) + } + return b.String() +} diff --git a/internal/ui/model/breadcrumb.go b/internal/ui/model/breadcrumb.go new file mode 100644 index 0000000000..e42dd022cf --- /dev/null +++ b/internal/ui/model/breadcrumb.go @@ -0,0 +1,24 @@ +package model + +import ( + "charm.land/lipgloss/v2" + "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/x/ansi" +) + +// parentBreadcrumbLine renders a one-line breadcrumb indicating the parent +// session. It returns an empty string when title is empty. The dot is colored +// with the subagent's palette color. The rendered line fits within width +// terminal columns. +func parentBreadcrumbLine(t *styles.Styles, color, title string, width int) string { + if title == "" { + return "" + } + + dot := styles.SubagentDot(color) + prefix := " ↑ parent: " + maxTitleWidth := max(width-lipgloss.Width(dot)-lipgloss.Width(prefix), 0) + truncated := ansi.Truncate(title, maxTitleWidth, "…") + + return dot + t.Resource.AdditionalText.Render(prefix+truncated) +} diff --git a/internal/ui/model/breadcrumb_test.go b/internal/ui/model/breadcrumb_test.go new file mode 100644 index 0000000000..aa2f9c86cb --- /dev/null +++ b/internal/ui/model/breadcrumb_test.go @@ -0,0 +1,49 @@ +package model + +import ( + "strings" + "testing" + + "charm.land/bubbles/v2/key" + "charm.land/lipgloss/v2" + uistyles "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/stretchr/testify/require" +) + +func TestParentBreadcrumbLine_EmptyTitle(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + got := parentBreadcrumbLine(&st, "", "", 40) + require.Empty(t, got) +} + +func TestParentBreadcrumbLine_WithTitle(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + got := parentBreadcrumbLine(&st, "blue", "Main Session", 40) + require.Contains(t, stripANSI(got), "↑ parent: Main Session") +} + +func TestParentBreadcrumbLine_WidthRespected(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + longTitle := "This Is An Extremely Long Session Title That Will Not Fit" + got := parentBreadcrumbLine(&st, "", longTitle, 20) + for line := range strings.SplitSeq(got, "\n") { + require.LessOrEqual(t, lipgloss.Width(line), 20, "line exceeds width: %q", line) + } +} + +func TestDefaultKeyMap_HasParentSessionBinding(t *testing.T) { + t.Parallel() + + km := DefaultKeyMap() + require.True(t, km.ParentSession.Enabled(), "ParentSession binding should be enabled") + require.Contains(t, km.ParentSession.Keys(), "ctrl+up") +} + +// Verify that key.Binding is the type we're using (compile-time check). +var _ key.Binding = DefaultKeyMap().ParentSession diff --git a/internal/ui/model/keys.go b/internal/ui/model/keys.go index ebf377035e..00b08fd004 100644 --- a/internal/ui/model/keys.go +++ b/internal/ui/model/keys.go @@ -57,14 +57,16 @@ type KeyMap struct { } // Global key maps - Quit key.Binding - Help key.Binding - Commands key.Binding - Models key.Binding - Suspend key.Binding - Sessions key.Binding - Tab key.Binding - ToggleYolo key.Binding + Quit key.Binding + Help key.Binding + Commands key.Binding + Models key.Binding + Suspend key.Binding + Sessions key.Binding + Tab key.Binding + ToggleYolo key.Binding + ParentSession key.Binding + Subagents key.Binding } func DefaultKeyMap() KeyMap { @@ -101,6 +103,14 @@ func DefaultKeyMap() KeyMap { key.WithKeys("ctrl+y"), key.WithHelp("ctrl+y", "toggle yolo"), ), + ParentSession: key.NewBinding( + key.WithKeys("ctrl+up"), + key.WithHelp("ctrl+up", "go to parent session"), + ), + Subagents: key.NewBinding( + key.WithKeys("ctrl+x"), + key.WithHelp("ctrl+x", "subagents"), + ), } km.Editor.AddFile = key.NewBinding( diff --git a/internal/ui/model/keys_test.go b/internal/ui/model/keys_test.go new file mode 100644 index 0000000000..296cf1ab50 --- /dev/null +++ b/internal/ui/model/keys_test.go @@ -0,0 +1,21 @@ +package model + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestDefaultKeyMap_HasSubagentsBinding verifies that the global KeyMap +// includes an enabled Subagents binding bound to ctrl+x. ctrl+x avoids the +// readline start-of-line collision that ctrl+a caused in the editor. +func TestDefaultKeyMap_HasSubagentsBinding(t *testing.T) { + t.Parallel() + + km := DefaultKeyMap() + + require.True(t, km.Subagents.Enabled(), "Subagents binding should be enabled") + require.Contains(t, km.Subagents.Keys(), "ctrl+x") + require.NotContains(t, km.Subagents.Keys(), "ctrl+a", + "ctrl+a collides with readline start-of-line in the editor") +} diff --git a/internal/ui/model/session.go b/internal/ui/model/session.go index aa31009b89..4099a3bfc1 100644 --- a/internal/ui/model/session.go +++ b/internal/ui/model/session.go @@ -95,6 +95,35 @@ func (m *UI) loadSession(sessionID string) tea.Cmd { return tea.Batch(load, m.reportCurrentSession(sessionID)) } +// fetchParentMeta fetches the parent session title (for the breadcrumb) and +// this child's subagent color (from the in-memory runtime). Both lookups run +// off the Update path. childSessionID may be empty when only the title matters. +func (m *UI) fetchParentMeta(parentSessionID, childSessionID string) tea.Cmd { + return func() tea.Msg { + sess, err := m.com.Workspace.GetSession(context.Background(), parentSessionID) + if err != nil { + return nil + } + var color string + for _, entry := range m.com.Workspace.RunningSubagents(parentSessionID) { + if entry.ChildSessionID == childSessionID { + color = entry.Color + break + } + } + return parentTitleMsg{title: sess.Title, color: color} + } +} + +// refreshRunningSubagents resolves the running-subagent list for sessionID off +// the Update path (RunningSubagents hits the DB to enrich token counts) and +// delivers it as a runningSubagentsMsg. +func (m *UI) refreshRunningSubagents(sessionID string) tea.Cmd { + return func() tea.Msg { + return runningSubagentsMsg{list: m.com.Workspace.RunningSubagents(sessionID)} + } +} + // reportCurrentSession returns a fire-and-forget tea.Cmd that // informs the workspace which session this client is currently // viewing. Errors are logged at debug only; the call is a hint diff --git a/internal/ui/model/session_test.go b/internal/ui/model/session_test.go index 7aa1a3f9f3..0f43fcd804 100644 --- a/internal/ui/model/session_test.go +++ b/internal/ui/model/session_test.go @@ -1,12 +1,17 @@ package model import ( + "context" + "errors" "strings" "testing" "charm.land/lipgloss/v2" "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/session" + "github.com/charmbracelet/crush/internal/ui/common" "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/crush/internal/workspace" "github.com/stretchr/testify/require" ) @@ -33,7 +38,7 @@ func TestFileList(t *testing.T) { } got := fileList(st, "/", files, 10, 10) plain := stripANSI(got) - for _, line := range strings.Split(plain, "\n") { + for line := range strings.SplitSeq(plain, "\n") { require.LessOrEqual(t, lipgloss.Width(line), 10, "line exceeds sidebar width: %q", line) } }) @@ -49,7 +54,7 @@ func TestFileList(t *testing.T) { plain := stripANSI(got) require.Contains(t, plain, "+5") require.Contains(t, plain, "-3") - for _, line := range strings.Split(plain, "\n") { + for line := range strings.SplitSeq(plain, "\n") { require.LessOrEqual(t, lipgloss.Width(line), 20, "line exceeds sidebar width: %q", line) } }) @@ -78,7 +83,7 @@ func TestFileList(t *testing.T) { plain := stripANSI(got) require.Contains(t, plain, "+3") require.NotContains(t, plain, "-0") - for _, line := range strings.Split(plain, "\n") { + for line := range strings.SplitSeq(plain, "\n") { require.LessOrEqual(t, lipgloss.Width(line), 20, "line exceeds sidebar width: %q", line) } }) @@ -94,7 +99,7 @@ func TestFileList(t *testing.T) { plain := stripANSI(got) require.NotContains(t, plain, "+0") require.Contains(t, plain, "-7") - for _, line := range strings.Split(plain, "\n") { + for line := range strings.SplitSeq(plain, "\n") { require.LessOrEqual(t, lipgloss.Width(line), 20, "line exceeds sidebar width: %q", line) } }) @@ -140,3 +145,56 @@ func stripANSI(s string) string { } return b.String() } + +func TestFetchParentMeta_ReturnsTitleAndColor(t *testing.T) { + t.Parallel() + + ws := &getSessionWorkspace{ + sessions: map[string]session.Session{ + "parent-id": {ID: "parent-id", Title: "My Parent Session"}, + }, + running: map[string][]workspace.RunningSubagentInfo{ + "parent-id": {{ChildSessionID: "child-id", Color: "purple"}}, + }, + } + m := &UI{com: &common.Common{Workspace: ws}} + + cmd := m.fetchParentMeta("parent-id", "child-id") + msg := cmd() + + ptm, ok := msg.(parentTitleMsg) + require.True(t, ok, "expected parentTitleMsg") + require.Equal(t, "My Parent Session", ptm.title) + require.Equal(t, "purple", ptm.color) +} + +func TestFetchParentMeta_NotFoundReturnsNil(t *testing.T) { + t.Parallel() + + ws := &getSessionWorkspace{sessions: map[string]session.Session{}} + m := &UI{com: &common.Common{Workspace: ws}} + + cmd := m.fetchParentMeta("missing", "") + msg := cmd() + + require.Nil(t, msg) +} + +// getSessionWorkspace stubs GetSession and RunningSubagents for the +// fetchParentMeta tests. +type getSessionWorkspace struct { + workspace.Workspace + sessions map[string]session.Session + running map[string][]workspace.RunningSubagentInfo +} + +func (w *getSessionWorkspace) RunningSubagents(parentSessionID string) []workspace.RunningSubagentInfo { + return w.running[parentSessionID] +} + +func (w *getSessionWorkspace) GetSession(_ context.Context, sessionID string) (session.Session, error) { + if sess, ok := w.sessions[sessionID]; ok { + return sess, nil + } + return session.Session{}, errors.New("not found") +} diff --git a/internal/ui/model/sidebar.go b/internal/ui/model/sidebar.go index e98ef33423..6a8e16fb02 100644 --- a/internal/ui/model/sidebar.go +++ b/internal/ui/model/sidebar.go @@ -150,12 +150,18 @@ func (m *UI) drawSidebar(scr uv.Screen, area uv.Rectangle) { blocks := []string{ sidebarLogo, title, + } + if bc := parentBreadcrumbLine(t, m.subagentColor, m.parentTitle, width); bc != "" { + blocks = append(blocks, bc) + } + blocks = append( + blocks, "", cwd, "", m.modelInfo(width), "", - } + ) sidebarHeader := lipgloss.JoinVertical( lipgloss.Left, diff --git a/internal/ui/model/subagent_rewrite.go b/internal/ui/model/subagent_rewrite.go new file mode 100644 index 0000000000..acf7f1be79 --- /dev/null +++ b/internal/ui/model/subagent_rewrite.go @@ -0,0 +1,49 @@ +package model + +import ( + "strings" + + "github.com/charmbracelet/crush/internal/ui/completions" + "github.com/charmbracelet/crush/internal/workspace" +) + +// buildSubagentCaches projects the workspace's active subagents into the two +// shapes the UI consumes: completion items (for the @-mention picker) and a +// name set (for sendMessage rewriting). Iteration order matches the input so +// completion ordering is deterministic. +func buildSubagentCaches(active []workspace.SubagentInfo) ([]completions.SubagentCompletionValue, map[string]bool) { + items := make([]completions.SubagentCompletionValue, len(active)) + names := make(map[string]bool, len(active)) + for i, sa := range active { + items[i] = completions.SubagentCompletionValue{Name: sa.Name, Description: sa.Description} + names[sa.Name] = true + } + return items, names +} + +// rebuildSubagentCaches refreshes the @-mention completion caches from the +// workspace's current active subagents. Called when Library discovery changes. +func (m *UI) rebuildSubagentCaches() { + m.activeSubagentItems, m.activeSubagentNames = buildSubagentCaches(m.com.Workspace.ActiveSubagents()) +} + +// rewriteSubagentPrompt detects the pattern `@name rest` at the start of +// content and rewrites it to a delegation instruction when name is a known +// active subagent. Returns content unchanged if the pattern doesn't match. +func rewriteSubagentPrompt(content string, activeNames map[string]bool) string { + if !strings.HasPrefix(content, "@") { + return content + } + name, prompt, ok := strings.Cut(content[1:], " ") + if !ok { + return content + } + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return content + } + if !activeNames[name] { + return content + } + return `Use the agent tool with subagent_type="` + name + `" to handle this request: ` + prompt +} diff --git a/internal/ui/model/subagent_rewrite_test.go b/internal/ui/model/subagent_rewrite_test.go new file mode 100644 index 0000000000..58818b87c3 --- /dev/null +++ b/internal/ui/model/subagent_rewrite_test.go @@ -0,0 +1,193 @@ +package model + +import ( + "testing" + + "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/workspace" + "github.com/stretchr/testify/require" +) + +// activeSubagentsWorkspace stubs ActiveSubagents for rebuildSubagentCaches. +type activeSubagentsWorkspace struct { + workspace.Workspace + active []workspace.SubagentInfo +} + +func (w *activeSubagentsWorkspace) ActiveSubagents() []workspace.SubagentInfo { return w.active } + +// TestRebuildSubagentCaches verifies the handler invoked on a subagents.Event +// rebuilds the @-mention caches from the workspace's current active list, so a +// removed subagent stops being offered without a restart. +func TestRebuildSubagentCaches(t *testing.T) { + t.Parallel() + + ws := &activeSubagentsWorkspace{active: []workspace.SubagentInfo{{Name: "alpha"}, {Name: "beta"}}} + m := &UI{com: &common.Common{Workspace: ws}} + + m.rebuildSubagentCaches() + require.True(t, m.activeSubagentNames["alpha"]) + require.True(t, m.activeSubagentNames["beta"]) + require.Len(t, m.activeSubagentItems, 2) + + // Discovery change drops beta — cache must reflect it on rebuild. + ws.active = []workspace.SubagentInfo{{Name: "alpha"}} + m.rebuildSubagentCaches() + require.True(t, m.activeSubagentNames["alpha"]) + require.False(t, m.activeSubagentNames["beta"], "removed subagent must drop from cache") + require.Len(t, m.activeSubagentItems, 1) +} + +func TestBuildSubagentCaches(t *testing.T) { + t.Parallel() + + t.Run("empty_input", func(t *testing.T) { + t.Parallel() + items, names := buildSubagentCaches(nil) + require.Empty(t, items) + require.Empty(t, names) + require.NotNil(t, names, "names map must be allocated even when empty") + }) + + t.Run("populates_both_caches", func(t *testing.T) { + t.Parallel() + got, names := buildSubagentCaches([]workspace.SubagentInfo{ + {Name: "code-reviewer", Description: "reviews code"}, + {Name: "tester", Description: "writes tests"}, + }) + + require.Len(t, got, 2) + require.Equal(t, "code-reviewer", got[0].Name) + require.Equal(t, "reviews code", got[0].Description) + require.Equal(t, "tester", got[1].Name) + require.True(t, names["code-reviewer"]) + require.True(t, names["tester"]) + require.False(t, names["missing"]) + }) + + t.Run("preserves_input_order", func(t *testing.T) { + t.Parallel() + got, _ := buildSubagentCaches([]workspace.SubagentInfo{ + {Name: "zeta"}, + {Name: "alpha"}, + {Name: "mu"}, + }) + require.Equal(t, "zeta", got[0].Name) + require.Equal(t, "alpha", got[1].Name) + require.Equal(t, "mu", got[2].Name) + }) +} + +// TestSendMessageRewriteFlow verifies the integration between the cached +// activeSubagentNames produced at UI init and the rewriteSubagentPrompt call +// at the head of sendMessage. Failing this test would mean the caches do not +// line up with the rewrite logic. +func TestSendMessageRewriteFlow(t *testing.T) { + t.Parallel() + + _, names := buildSubagentCaches([]workspace.SubagentInfo{ + {Name: "code-reviewer", Description: "Reviews code."}, + }) + + got := rewriteSubagentPrompt("@code-reviewer review staged", names) + require.Equal(t, `Use the agent tool with subagent_type="code-reviewer" to handle this request: review staged`, got) + + // Unknown name passes through unchanged. + got = rewriteSubagentPrompt("@missing do thing", names) + require.Equal(t, "@missing do thing", got) +} + +// TestRewriteSubagentPrompt covers the pure helper rewriteSubagentPrompt which +// detects an `@name rest` prefix pattern and rewrites it into the canonical +// agent-tool dispatch form when name is in the provided active-names set. +func TestRewriteSubagentPrompt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + activeNames map[string]bool + want string + }{ + { + name: "no_at_prefix", + content: "just a normal message", + activeNames: map[string]bool{"code-reviewer": true}, + want: "just a normal message", + }, + { + name: "at_unknown_name", + content: "@unknown do something", + activeNames: map[string]bool{}, + want: "@unknown do something", + }, + { + name: "at_known_single_word", + content: "@code-reviewer review staged", + activeNames: map[string]bool{"code-reviewer": true}, + want: `Use the agent tool with subagent_type="code-reviewer" to handle this request: review staged`, + }, + { + name: "at_known_multiword_rest", + content: "@tester write tests for the auth module please", + activeNames: map[string]bool{"tester": true}, + want: `Use the agent tool with subagent_type="tester" to handle this request: write tests for the auth module please`, + }, + { + name: "at_name_no_space_after", + content: "@code-reviewer", + activeNames: map[string]bool{"code-reviewer": true}, + want: "@code-reviewer", + }, + { + name: "at_name_only_whitespace_after", + content: "@code-reviewer ", + activeNames: map[string]bool{"code-reviewer": true}, + want: "@code-reviewer ", + }, + { + name: "empty_content", + content: "", + activeNames: map[string]bool{"code-reviewer": true}, + want: "", + }, + { + name: "at_name_with_leading_space", + content: " @code-reviewer review it", + activeNames: map[string]bool{"code-reviewer": true}, + want: " @code-reviewer review it", + }, + { + name: "nil_active_names_does_not_panic", + content: "@code-reviewer review it", + activeNames: nil, + want: "@code-reviewer review it", + }, + { + name: "newline_as_separator_not_supported", + content: "@code-reviewer\nreview this", + activeNames: map[string]bool{"code-reviewer": true}, + want: "@code-reviewer\nreview this", + }, + { + name: "multiple_at_mentions_only_first_rewritten", + content: "@code-reviewer review and @tester test", + activeNames: map[string]bool{"code-reviewer": true, "tester": true}, + want: `Use the agent tool with subagent_type="code-reviewer" to handle this request: review and @tester test`, + }, + { + name: "tab_after_name_not_supported_as_separator", + content: "@code-reviewer\treview this", + activeNames: map[string]bool{"code-reviewer": true}, + want: "@code-reviewer\treview this", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := rewriteSubagentPrompt(tt.content, tt.activeNames) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/ui/model/subagents_panel.go b/internal/ui/model/subagents_panel.go new file mode 100644 index 0000000000..519fb171cf --- /dev/null +++ b/internal/ui/model/subagents_panel.go @@ -0,0 +1,74 @@ +package model + +import ( + "fmt" + + "charm.land/lipgloss/v2" + "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/ui/styles" +) + +type subagentStatusItem struct { + icon string + name string + title string + description string +} + +// subagentsInfo renders the running subagents status section. +func (m *UI) subagentsInfo(width, maxItems int, isSection bool) string { + t := m.com.Styles + + title := t.Resource.Heading.Render("Subagents") + if isSection { + title = common.Section(t, title, width) + } + + if len(m.runningSubagents) == 0 { + list := t.Resource.AdditionalText.Render("None") + return lipgloss.NewStyle().Width(width).Render(fmt.Sprintf("%s\n\n%s", title, list)) + } + + items := make([]subagentStatusItem, 0, len(m.runningSubagents)) + for _, e := range m.runningSubagents { + tokens := e.PromptTokens + e.CompletionTokens + desc := e.Model + if tokens > 0 { + desc = fmt.Sprintf("%s %s", e.Model, t.Resource.AdditionalText.Render(fmt.Sprintf("%d tok", tokens))) + } + items = append(items, subagentStatusItem{ + icon: styles.SubagentDot(e.Color), + name: e.Name, + title: t.Resource.Name.Render(e.Name), + description: desc, + }) + } + + list := subagentsList(t, items, width, maxItems) + return lipgloss.NewStyle().Width(width).Render(fmt.Sprintf("%s\n\n%s", title, list)) +} + +func subagentsList(t *styles.Styles, items []subagentStatusItem, width, maxItems int) string { + if maxItems <= 0 { + return "" + } + + if len(items) > maxItems { + visibleItems := items[:maxItems-1] + remaining := len(items) - (maxItems - 1) + items = append(visibleItems, subagentStatusItem{ + name: "more", + title: t.Resource.AdditionalText.Render(fmt.Sprintf("…and %d more", remaining)), + }) + } + + renderedItems := make([]string, 0, len(items)) + for _, item := range items { + renderedItems = append(renderedItems, common.Status(t, common.StatusOpts{ + Icon: item.icon, + Title: item.title, + Description: item.description, + }, width)) + } + return lipgloss.JoinVertical(lipgloss.Left, renderedItems...) +} diff --git a/internal/ui/model/subagents_panel_test.go b/internal/ui/model/subagents_panel_test.go new file mode 100644 index 0000000000..d3d212ff3f --- /dev/null +++ b/internal/ui/model/subagents_panel_test.go @@ -0,0 +1,88 @@ +package model + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/charmbracelet/crush/internal/ui/common" + uistyles "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/crush/internal/workspace" +) + +func TestSubagentsInfo_Empty(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + m := &UI{ + com: &common.Common{Styles: &st}, + runningSubagents: nil, + } + + got := m.subagentsInfo(40, 10, false) + + require.NotEmpty(t, got) + require.Contains(t, stripANSI(got), "Subagents") + require.Contains(t, stripANSI(got), "None") +} + +func TestSubagentsInfo_SingleEntry(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + m := &UI{ + com: &common.Common{Styles: &st}, + runningSubagents: []workspace.RunningSubagentInfo{ + {Name: "test-agent", Color: "blue"}, + }, + } + + got := m.subagentsInfo(40, 10, false) + + require.Contains(t, stripANSI(got), "test-agent") + dot := uistyles.SubagentDot("blue") + require.Contains(t, got, dot) +} + +func TestSubagentsInfo_MultipleEntries(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + m := &UI{ + com: &common.Common{Styles: &st}, + runningSubagents: []workspace.RunningSubagentInfo{ + {Name: "alpha-agent", Color: "red"}, + {Name: "beta-agent", Color: "green"}, + {Name: "gamma-agent", Color: "purple"}, + }, + } + + got := m.subagentsInfo(40, 10, false) + plain := stripANSI(got) + + for _, name := range []string{"alpha-agent", "beta-agent", "gamma-agent"} { + require.Contains(t, plain, name, "expected %q in output", name) + } +} + +func TestSubagentsInfo_TruncatesAtMaxItems(t *testing.T) { + t.Parallel() + + st := uistyles.CharmtonePantera() + m := &UI{ + com: &common.Common{Styles: &st}, + runningSubagents: []workspace.RunningSubagentInfo{ + {Name: "agent-one", Color: "red"}, + {Name: "agent-two", Color: "green"}, + {Name: "agent-three", Color: "blue"}, + }, + } + + // maxItems=2 with 3 entries: the list helper reserves one slot for the + // trailer, so visibleItems = items[:1] and remaining = 3-(2-1) = 2, + // producing "…and 2 more" — mirrors the skillsList truncation pattern. + got := m.subagentsInfo(40, 2, false) + plain := stripANSI(got) + + require.Contains(t, plain, "…and 2 more") +} diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 9dfe722796..0a1defab61 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -42,6 +42,7 @@ import ( "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/skills" "github.com/charmbracelet/crush/internal/stringext" + "github.com/charmbracelet/crush/internal/subagents" "github.com/charmbracelet/crush/internal/ui/anim" "github.com/charmbracelet/crush/internal/ui/attachments" "github.com/charmbracelet/crush/internal/ui/chat" @@ -159,6 +160,19 @@ type ( creditsUpdatedMsg struct { credits int } + + // parentTitleMsg is sent when the parent session metadata has been + // fetched: the title for the breadcrumb and this child's subagent color. + parentTitleMsg struct { + title string + color string + } + + // runningSubagentsMsg carries the refreshed running-subagent list, + // resolved off the Update path to keep DB IO out of the message loop. + runningSubagentsMsg struct { + list []workspace.RunningSubagentInfo + } ) // UI represents the main user interface model. @@ -239,6 +253,14 @@ type UI struct { // skills skillStates []*skills.SkillState + // runningSubagents holds the live subagent list for the current session, + // refreshed on each RuntimeEvent. + runningSubagents []workspace.RunningSubagentInfo + + // Subagents — cached at init, static for session lifetime. + activeSubagentItems []completions.SubagentCompletionValue + activeSubagentNames map[string]bool + // sidebarLogo keeps a cached version of the sidebar sidebarLogo. sidebarLogo string @@ -282,6 +304,12 @@ type UI struct { index int draft string } + + // parentTitle holds the resolved parent session title for the breadcrumb. + // subagentColor holds this child session's subagent color, looked up from + // the runtime when the session loads. + parentTitle string + subagentColor string } // New creates a new instance of the [UI] model. @@ -307,7 +335,6 @@ func New(com *common.Common, initialSessionID string, continueLast bool) *UI { com.Styles.Completions.Focused, com.Styles.Completions.Match, ) - todoSpinner := spinner.New( spinner.WithSpinner(spinner.MiniDot), spinner.WithStyle(com.Styles.Pills.TodoSpinner), @@ -350,6 +377,9 @@ func New(com *common.Common, initialSessionID string, continueLast bool) *UI { skillStates: skills.GetLatestStates(), } + // Cache active subagents once — they are static for the session. + ui.activeSubagentItems, ui.activeSubagentNames = buildSubagentCaches(com.Workspace.ActiveSubagents()) + status := NewStatus(com, ui) ui.setEditorPrompt(com.Workspace.PermissionSkipRequests()) @@ -588,6 +618,8 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.setState(uiChat, m.focus) m.session = msg.session m.sessionFiles = msg.files + m.parentTitle = "" + m.subagentColor = "" cmds = append(cmds, m.startLSPs(msg.lspFilePaths())) msgs, err := m.com.Workspace.ListMessages(context.Background(), m.session.ID) if err != nil { @@ -611,8 +643,15 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Reload prompt history for the new session. m.historyReset() cmds = append(cmds, m.loadPromptHistory()) + if m.session.ParentSessionID != "" { + cmds = append(cmds, m.fetchParentMeta(m.session.ParentSessionID, m.session.ID)) + } m.updateLayoutAndSize() + case parentTitleMsg: + m.parentTitle = msg.title + m.subagentColor = msg.color + case sessionFilesUpdatesMsg: m.sessionFiles = msg.sessionFiles var paths []string @@ -714,6 +753,24 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.lspStates = app.GetLSPStates() case pubsub.Event[skills.Event]: m.skillStates = msg.Payload.States + case pubsub.Event[subagents.RuntimeEvent]: + switch { + case m.session == nil: + m.runningSubagents = nil + case msg.Payload.ParentSessionID == m.session.ID: + // Only the current session's children populate the panel; ignore + // events for other parents to avoid spurious DB refreshes. + cmds = append(cmds, m.refreshRunningSubagents(m.session.ID)) + } + if f := msg.Payload.Finished; f != nil && m.session != nil && f.ParentSessionID == m.session.ID { + cmds = append(cmds, util.ReportInfo(fmt.Sprintf("Subagent %s %s", f.Name, f.Status))) + } + case runningSubagentsMsg: + m.runningSubagents = msg.list + case pubsub.Event[subagents.Event]: + // Library discovery changed (e.g. a delete) — rebuild the @-mention + // caches so removed subagents stop being offered without a restart. + m.rebuildSubagentCaches() case pubsub.Event[mcp.Event]: switch msg.Payload.Type { case mcp.EventStateChanged: @@ -964,7 +1021,7 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.status.ClearInfoMsg() case completions.CompletionItemsLoadedMsg: if m.completionsOpen { - m.completions.SetItems(msg.Files, msg.Resources) + m.completions.SetItems(msg.Files, msg.Resources, msg.Subagents) } case uv.KittyGraphicsEvent: if !bytes.HasPrefix(msg.Payload, []byte("OK")) { @@ -1404,6 +1461,11 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { m.dialog.CloseDialog(dialog.SessionsID) cmds = append(cmds, m.loadSession(msg.Session.ID)) + // Subagents dialog messages. + case dialog.ActionLoadSubagentSession: + m.dialog.CloseDialog(dialog.SubagentsID) + cmds = append(cmds, m.loadSession(msg.SessionID)) + // Open dialog message. case dialog.ActionOpenDialog: m.dialog.CloseDialog(dialog.CommandsID) @@ -1878,6 +1940,14 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { } cmds = append(cmds, util.ReportInfo("Yolo mode "+status)) return true + case key.Matches(msg, m.keyMap.ParentSession): + if m.session != nil && m.session.ParentSessionID != "" { + cmds = append(cmds, m.loadSession(m.session.ParentSessionID)) + return true + } + case key.Matches(msg, m.keyMap.Subagents): + m.openSubagentsDialog() + return true } return false } @@ -1929,6 +1999,11 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { if !msg.KeepOpen { m.closeCompletions() } + case completions.SelectionMsg[completions.SubagentCompletionValue]: + cmds = append(cmds, m.insertSubagentCompletion(msg.Value.Name)) + if !msg.KeepOpen { + m.closeCompletions() + } case completions.ClosedMsg: m.completionsOpen = false } @@ -2055,7 +2130,7 @@ func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { m.completionsStartIndex = curIdx m.completionsPositionStart = m.completionsPosition() depth, limit := m.com.Config().Options.TUI.Completions.Limits() - cmds = append(cmds, m.completions.Open(depth, limit)) + cmds = append(cmds, m.completions.Open(depth, limit, m.activeSubagentItems)) } } @@ -2396,6 +2471,7 @@ func (m *UI) ShortHelp() []key.Binding { tab, commands, k.Models, + k.Subagents, ) switch m.focus { @@ -2425,6 +2501,7 @@ func (m *UI) ShortHelp() []key.Binding { binds, commands, k.Models, + k.Subagents, k.Editor.Newline, ) } @@ -2482,6 +2559,7 @@ func (m *UI) FullHelp() [][]key.Binding { tab, commands, k.Models, + k.Subagents, k.Sessions, k.ToggleYolo, ) @@ -2544,6 +2622,7 @@ func (m *UI) FullHelp() [][]key.Binding { []key.Binding{ commands, k.Models, + k.Subagents, k.Sessions, k.ToggleYolo, }, @@ -3058,6 +3137,15 @@ func (m *UI) insertFileCompletion(path string) tea.Cmd { return tea.Batch(heightCmd, fileCmd) } +// insertSubagentCompletion inserts @name into the textarea, replacing the @query. +func (m *UI) insertSubagentCompletion(name string) tea.Cmd { + prevHeight := m.textarea.Height() + if !m.insertCompletionText("@" + name) { + return nil + } + return m.handleTextareaHeightChange(prevHeight) +} + // insertMCPResourceCompletion inserts the selected resource into the textarea, // replacing the @query, and adds the resource as an attachment. func (m *UI) insertMCPResourceCompletion(item completions.ResourceCompletionValue) tea.Cmd { @@ -3253,6 +3341,8 @@ func (m *UI) attachSkill(skillID, name string) tea.Cmd { // sendMessage sends a message with the given content and attachments. func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea.Cmd { + content = rewriteSubagentPrompt(content, m.activeSubagentNames) + if !m.com.Workspace.AgentIsReady() { return util.ReportError(fmt.Errorf("coder agent is not initialized")) } @@ -3351,6 +3441,10 @@ func (m *UI) openDialog(id string) tea.Cmd { if cmd := m.openSessionsDialog(); cmd != nil { cmds = append(cmds, cmd) } + case dialog.SubagentsID: + if cmd := m.openSubagentsDialog(); cmd != nil { + cmds = append(cmds, cmd) + } case dialog.ModelsID: if cmd := m.openModelsDialog(); cmd != nil { cmds = append(cmds, cmd) @@ -3492,6 +3586,24 @@ func (m *UI) openSessionsDialog() tea.Cmd { return nil } +// openSubagentsDialog opens the subagents dialog. If the dialog is already +// open, it brings it to the front. Subagent surfaces are local-mode only: in +// client/server mode the ClientWorkspace stubs return empty, so the dialog +// opens with no running or library entries. +func (m *UI) openSubagentsDialog() tea.Cmd { + if m.dialog.ContainsDialog(dialog.SubagentsID) { + m.dialog.BringToFront(dialog.SubagentsID) + return nil + } + sessionID := "" + if m.session != nil { + sessionID = m.session.ID + } + d := dialog.NewSubagents(m.com, sessionID) + m.dialog.OpenDialog(d) + return nil +} + // openFilesDialog opens the file picker dialog. func (m *UI) openFilesDialog() tea.Cmd { if m.dialog.ContainsDialog(dialog.FilePickerID) { @@ -3853,14 +3965,24 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) { remainingHeight := height - lipgloss.Height(detailsHeader) - lipgloss.Height(version) const maxSectionWidth = 50 - sectionWidth := max(1, min(maxSectionWidth, width/4-2)) // account for spacing between sections - maxItemsPerSection := remainingHeight - 3 // Account for section title and spacing + numSections := 4 + if len(m.runningSubagents) > 0 { + numSections = 5 + } + sectionWidth := max(1, min(maxSectionWidth, width/numSections-2)) // account for spacing between sections + maxItemsPerSection := remainingHeight - 3 // Account for section title and spacing lspSection := m.lspInfo(sectionWidth, maxItemsPerSection, false) mcpSection := m.mcpInfo(sectionWidth, maxItemsPerSection, false) skillsSection := m.skillsInfo(sectionWidth, maxItemsPerSection, false) filesSection := m.filesInfo(m.com.Workspace.WorkingDir(), sectionWidth, maxItemsPerSection, false) - sections := lipgloss.JoinHorizontal(lipgloss.Top, filesSection, " ", lspSection, " ", mcpSection, " ", skillsSection) + var sections string + if len(m.runningSubagents) > 0 { + subagentsSection := m.subagentsInfo(sectionWidth, maxItemsPerSection, false) + sections = lipgloss.JoinHorizontal(lipgloss.Top, filesSection, " ", lspSection, " ", mcpSection, " ", skillsSection, " ", subagentsSection) + } else { + sections = lipgloss.JoinHorizontal(lipgloss.Top, filesSection, " ", lspSection, " ", mcpSection, " ", skillsSection) + } uv.NewStyledString( s.CompactDetails.View. Width(area.Dx()). diff --git a/internal/ui/styles/subagent_palette.go b/internal/ui/styles/subagent_palette.go new file mode 100644 index 0000000000..8f74d78cd9 --- /dev/null +++ b/internal/ui/styles/subagent_palette.go @@ -0,0 +1,35 @@ +package styles + +import ( + "charm.land/lipgloss/v2" + "github.com/charmbracelet/x/exp/charmtone" +) + +// SubagentDot returns a colored "●" string for the given palette color name. +// Recognized names are the eight subagent palette colors (red, orange, yellow, +// green, cyan, blue, purple, pink). Unrecognized names return a plain "●" with +// no styling applied. +func SubagentDot(color string) string { + var fg charmtone.Key + switch color { + case "red": + fg = charmtone.Cherry + case "orange": + fg = charmtone.Tang + case "yellow": + fg = charmtone.Citron + case "green": + fg = charmtone.Julep + case "cyan": + fg = charmtone.Guppy + case "blue": + fg = charmtone.Sapphire + case "purple": + fg = charmtone.Mauve + case "pink": + fg = charmtone.Flamingo + default: + return "●" + } + return lipgloss.NewStyle().Foreground(fg).SetString("●").String() +} diff --git a/internal/workspace/active_subagents_test.go b/internal/workspace/active_subagents_test.go new file mode 100644 index 0000000000..4bdec1a0ef --- /dev/null +++ b/internal/workspace/active_subagents_test.go @@ -0,0 +1,52 @@ +package workspace + +import ( + "testing" + + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/subagents" + "github.com/stretchr/testify/require" +) + +func TestAppWorkspace_ActiveSubagents_NilManagerReturnsNil(t *testing.T) { + t.Parallel() + + w := &AppWorkspace{app: &app.App{}} + require.Nil(t, w.ActiveSubagents()) +} + +func TestAppWorkspace_ActiveSubagents_MapsManagerOutput(t *testing.T) { + t.Parallel() + + mgr := subagents.NewManager(nil, []*subagents.Subagent{ + {Name: "code-reviewer", Description: "Reviews code."}, + {Name: "tester", Description: "Writes tests."}, + }, nil) + t.Cleanup(mgr.Shutdown) + + w := &AppWorkspace{app: &app.App{Subagents: mgr}} + + got := w.ActiveSubagents() + require.Len(t, got, 2) + require.Equal(t, "code-reviewer", got[0].Name) + require.Equal(t, "Reviews code.", got[0].Description) + require.Equal(t, "tester", got[1].Name) + require.Equal(t, "Writes tests.", got[1].Description) +} + +func TestAppWorkspace_ActiveSubagents_EmptyManagerReturnsEmpty(t *testing.T) { + t.Parallel() + + mgr := subagents.NewManager(nil, nil, nil) + t.Cleanup(mgr.Shutdown) + + w := &AppWorkspace{app: &app.App{Subagents: mgr}} + require.Empty(t, w.ActiveSubagents()) +} + +func TestClientWorkspace_ActiveSubagents_AlwaysNil(t *testing.T) { + t.Parallel() + + w := &ClientWorkspace{} + require.Nil(t, w.ActiveSubagents()) +} diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go index c35a9f59fe..980c4f3d64 100644 --- a/internal/workspace/app_workspace.go +++ b/internal/workspace/app_workspace.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "os" + "strings" "time" tea "charm.land/bubbletea/v2" @@ -17,8 +19,10 @@ import ( "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/skills" + "github.com/charmbracelet/crush/internal/subagents" ) // AppWorkspace implements the Workspace interface by delegating @@ -322,6 +326,206 @@ func (w *AppWorkspace) ReadSkill(_ context.Context, skillID string) ([]byte, ski return skills.ReadContent(mgr.ActiveSkills(), mgr.ResolvedPaths(), mgr.WorkingDir(), skillID) } +// ActiveSubagents returns the workspace's post-filter list of active subagents +// projected to the frontend-facing SubagentInfo shape. Returns nil when the +// workspace has no Subagents manager configured. +func (w *AppWorkspace) ActiveSubagents() []SubagentInfo { + mgr := w.app.Subagents + if mgr == nil { + return nil + } + active := mgr.ActiveSubagents() + result := make([]SubagentInfo, len(active)) + for i, sa := range active { + result[i] = SubagentInfo{Name: sa.Name, Description: sa.Description} + } + return result +} + +// RunningSubagents returns info about all subagent sessions currently running +// under the given parentSessionID, enriched with token counts from the session +// service where available. Returns nil when SubagentRuntime is nil. +func (w *AppWorkspace) RunningSubagents(parentSessionID string) []RunningSubagentInfo { + if w.app.SubagentRuntime == nil { + return nil + } + entries := w.app.SubagentRuntime.List(parentSessionID) + if len(entries) == 0 { + return nil + } + result := make([]RunningSubagentInfo, len(entries)) + for i, e := range entries { + info := RunningSubagentInfo{ + ChildSessionID: e.ChildSessionID, + ParentSessionID: e.ParentSessionID, + Name: e.Name, + Color: e.Color, + Model: e.Model, + Status: e.Status, + StartedAt: e.StartedAt, + } + if w.app.Sessions != nil { + if sess, err := w.app.Sessions.Get(context.Background(), e.ChildSessionID); err == nil { + info.PromptTokens = sess.PromptTokens + info.CompletionTokens = sess.CompletionTokens + } + } + result[i] = info + } + return result +} + +// SubscribeSubagentRuntime returns a channel of RuntimeEvents from the +// SubagentRuntime. Returns a closed channel when SubagentRuntime is nil. +func (w *AppWorkspace) SubscribeSubagentRuntime(ctx context.Context) <-chan pubsub.Event[subagents.RuntimeEvent] { + if w.app.SubagentRuntime == nil { + ch := make(chan pubsub.Event[subagents.RuntimeEvent]) + close(ch) + return ch + } + return w.app.SubagentRuntime.Subscribe(ctx) +} + +// CancelSubagent cancels the subagent session with the given childSessionID. +// It is a no-op when AgentCoordinator is nil. +func (w *AppWorkspace) CancelSubagent(childSessionID string) { + if w.app.AgentCoordinator == nil { + return + } + w.app.AgentCoordinator.Cancel(childSessionID) +} + +// AllSubagents returns all discovered subagent definitions projected to the +// frontend-facing SubagentDefInfo shape, with scope detection relative to the +// workspace working directory. Returns nil when the Subagents manager is nil. +func (w *AppWorkspace) AllSubagents() []SubagentDefInfo { + mgr := w.app.Subagents + if mgr == nil { + return nil + } + all := mgr.AllSubagents() + cfg := w.store.Config() + workingDir := w.store.WorkingDir() + + var disabledSubagents []string + if cfg.Options != nil { + disabledSubagents = cfg.Options.DisabledSubagents + } + disabledSet := make(map[string]bool, len(disabledSubagents)) + for _, name := range disabledSubagents { + disabledSet[name] = true + } + + result := make([]SubagentDefInfo, len(all)) + for i, s := range all { + scope := "user" + if s.FilePath == "" { + scope = "builtin" + } else if workingDir != "" && (strings.HasPrefix(s.FilePath, workingDir+"/") || s.FilePath == workingDir) { + scope = "project" + } + result[i] = SubagentDefInfo{ + Name: s.Name, + Description: s.Description, + Color: s.ResolvedColor(), + FilePath: s.FilePath, + Scope: scope, + Disabled: disabledSet[s.Name], + } + } + return result +} + +// DeleteUserSubagent removes a user-scoped subagent by name. It returns an +// error if the subagent is not found or is not user-scoped. On success it +// deletes the file from disk and reloads the Subagents manager. +func (w *AppWorkspace) DeleteUserSubagent(name string) error { + var target *SubagentDefInfo + for _, info := range w.AllSubagents() { + if info.Name == name { + cp := info + target = &cp + break + } + } + if target == nil { + return fmt.Errorf("subagent %q not found", name) + } + if target.Scope != "user" { + return fmt.Errorf("subagent %q is not user-scoped and cannot be deleted", name) + } + if err := os.Remove(target.FilePath); err != nil { + return err + } + w.reloadSubagents() + return nil +} + +// SetSubagentDisabled enables or disables a subagent by name, persisting the +// change to options.disabled_subagents at project scope and reloading +// discovery. A disabled subagent is filtered out of the active set, which is +// what the dispatcher enum, dispatch lookup, @-mention completions, and the +// @-rewrite all derive from — so it can be neither auto-selected by the main +// agent nor invoked manually. +func (w *AppWorkspace) SetSubagentDisabled(name string, disabled bool) error { + var current []string + if cfg := w.store.Config(); cfg.Options != nil { + current = cfg.Options.DisabledSubagents + } + next := addOrRemove(current, name, disabled) + if err := w.store.SetConfigField(config.ScopeWorkspace, "options.disabled_subagents", next); err != nil { + return err + } + w.reloadSubagents() + return nil +} + +// reloadSubagents re-runs discovery from the current config and swaps the +// Manager's snapshot, publishing a discovery event. Model ids are validated +// against the config (matching startup) so an invalid model stays rejected. +func (w *AppWorkspace) reloadSubagents() { + cfg := w.store.Config() + var subagentsPaths, disabledSubagents []string + if cfg.Options != nil { + subagentsPaths = cfg.Options.SubagentsPaths + disabledSubagents = cfg.Options.DisabledSubagents + } + all, active, states := subagents.DiscoverFromConfig(subagents.DiscoveryConfig{ + SubagentsPaths: subagentsPaths, + DisabledSubagents: disabledSubagents, + + // Match startup discovery (cmd/root.go, backend.go): validate model + // ids so a subagent with an invalid model stays rejected after reload. + IsKnownModel: cfg.IsKnownModel, + }) + w.app.Subagents.Reload(all, active, states) +} + +// addOrRemove returns list with name added (when add) or all occurrences +// removed (when !add). The result is a fresh slice; order is otherwise stable. +func addOrRemove(list []string, name string, add bool) []string { + next := make([]string, 0, len(list)+1) + for _, n := range list { + if n != name { + next = append(next, n) + } + } + if add { + next = append(next, name) + } + return next +} + +// SessionTokens returns the prompt and completion token counts for the given +// session. It delegates to the session service and propagates any error. +func (w *AppWorkspace) SessionTokens(ctx context.Context, sessionID string) (prompt, completion int64, err error) { + sess, err := w.app.Sessions.Get(ctx, sessionID) + if err != nil { + return 0, 0, err + } + return sess.PromptTokens, sess.CompletionTokens, nil +} + // -- MCP operations -- func (w *AppWorkspace) MCPGetStates() map[string]mcptools.ClientInfo { diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index 09ff57c612..2c7346982f 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -24,6 +24,7 @@ import ( "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/skills" + "github.com/charmbracelet/crush/internal/subagents" "github.com/charmbracelet/x/powernap/pkg/lsp/protocol" ) @@ -530,6 +531,53 @@ func (w *ClientWorkspace) ReadSkill(ctx context.Context, skillID string) ([]byte }, nil } +// -- Subagents (local-mode only) -- +// +// All subagent surfaces are unimplemented over RPC: discovery, the running +// runtime, cancellation, and deletion are server-side concerns the client does +// not expose today. These stubs return empty/no-op, so in client/server mode +// the Subagents dialog opens with no entries. + +// ActiveSubagents returns nil in client mode. +func (w *ClientWorkspace) ActiveSubagents() []SubagentInfo { + return nil +} + +// RunningSubagents returns nil in client mode. +func (w *ClientWorkspace) RunningSubagents(_ string) []RunningSubagentInfo { + return nil +} + +// SubscribeSubagentRuntime returns a closed channel in client mode. +func (w *ClientWorkspace) SubscribeSubagentRuntime(_ context.Context) <-chan pubsub.Event[subagents.RuntimeEvent] { + ch := make(chan pubsub.Event[subagents.RuntimeEvent]) + close(ch) + return ch +} + +// CancelSubagent is a no-op in client mode. +func (w *ClientWorkspace) CancelSubagent(_ string) {} + +// AllSubagents returns nil in client mode. +func (w *ClientWorkspace) AllSubagents() []SubagentDefInfo { + return nil +} + +// DeleteUserSubagent returns an error in client mode. +func (w *ClientWorkspace) DeleteUserSubagent(_ string) error { + return nil +} + +// SetSubagentDisabled is a no-op in client mode. +func (w *ClientWorkspace) SetSubagentDisabled(_ string, _ bool) error { + return nil +} + +// SessionTokens returns zero token counts in client mode. +func (w *ClientWorkspace) SessionTokens(_ context.Context, _ string) (int64, int64, error) { + return 0, 0, nil +} + // -- MCP operations -- func (w *ClientWorkspace) MCPGetStates() map[string]mcp.ClientInfo { diff --git a/internal/workspace/running_subagents_test.go b/internal/workspace/running_subagents_test.go new file mode 100644 index 0000000000..d065c3fe94 --- /dev/null +++ b/internal/workspace/running_subagents_test.go @@ -0,0 +1,485 @@ +package workspace + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/session" + "github.com/charmbracelet/crush/internal/subagents" +) + +// -- minimal session.Service stub for token-enrichment tests -- + +type stubSessionService struct { + sessions map[string]session.Session +} + +func (s *stubSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] { + return make(chan pubsub.Event[session.Session]) +} + +func (s *stubSessionService) Create(_ context.Context, title string) (session.Session, error) { + return session.Session{ID: "new", Title: title}, nil +} + +func (s *stubSessionService) CreateTitleSession(context.Context, string) (session.Session, error) { + return session.Session{}, nil +} + +func (s *stubSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) { + return session.Session{}, nil +} + +func (s *stubSessionService) Get(_ context.Context, id string) (session.Session, error) { + if sess, ok := s.sessions[id]; ok { + return sess, nil + } + return session.Session{}, sql.ErrNoRows +} + +func (s *stubSessionService) GetLast(context.Context) (session.Session, error) { + return session.Session{}, sql.ErrNoRows +} + +func (s *stubSessionService) List(context.Context) ([]session.Session, error) { + return nil, nil +} + +func (s *stubSessionService) Save(_ context.Context, sess session.Session) (session.Session, error) { + return sess, nil +} + +func (s *stubSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error { + return nil +} + +func (s *stubSessionService) Rename(context.Context, string, string) error { return nil } + +func (s *stubSessionService) Delete(context.Context, string) error { return nil } + +func (s *stubSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string { + return fmt.Sprintf("%s$$%s", messageID, toolCallID) +} + +func (s *stubSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) { + parts := strings.Split(sessionID, "$$") + if len(parts) != 2 { + return "", "", false + } + return parts[0], parts[1], true +} + +func (s *stubSessionService) IsAgentToolSession(sessionID string) bool { + _, _, ok := s.ParseAgentToolSessionID(sessionID) + return ok +} + +// newStoreForWorkDir returns a ConfigStore whose WorkingDir() reports workDir. +func newStoreForWorkDir(workDir string) *config.ConfigStore { + return config.NewTestStoreWithWorkingDir(&config.Config{}, workDir) +} + +// TestAppWorkspace_RunningSubagents_Empty verifies that a nil SubagentRuntime +// returns a nil slice without panicking. +func TestAppWorkspace_RunningSubagents_Empty(t *testing.T) { + t.Parallel() + + w := &AppWorkspace{ + app: &app.App{SubagentRuntime: nil}, + store: config.NewTestStore(&config.Config{}), + } + + got := w.RunningSubagents("parent-1") + require.Nil(t, got) +} + +// TestAppWorkspace_RunningSubagents_WithEntries verifies that entries registered +// on the Runtime are mapped to RunningSubagentInfo with the correct fields. +func TestAppWorkspace_RunningSubagents_WithEntries(t *testing.T) { + t.Parallel() + + rt := subagents.NewRuntime() + t.Cleanup(rt.Shutdown) + + rt.Register("parent-1", "child-A", "agent-alpha", "blue", "") + rt.Register("parent-1", "child-B", "agent-beta", "red", "") + + w := &AppWorkspace{ + app: &app.App{ + SubagentRuntime: rt, + Sessions: &stubSessionService{sessions: map[string]session.Session{}}, + }, + store: config.NewTestStore(&config.Config{}), + } + + got := w.RunningSubagents("parent-1") + require.Len(t, got, 2) + + byChild := map[string]RunningSubagentInfo{} + for _, info := range got { + byChild[info.ChildSessionID] = info + } + + a := byChild["child-A"] + require.Equal(t, "parent-1", a.ParentSessionID) + require.Equal(t, "agent-alpha", a.Name) + require.Equal(t, "blue", a.Color) + require.Equal(t, "running", a.Status) + require.False(t, a.StartedAt.IsZero()) + + b := byChild["child-B"] + require.Equal(t, "agent-beta", b.Name) + require.Equal(t, "red", b.Color) +} + +// TestAppWorkspace_RunningSubagents_TokenEnrichment verifies that when a child +// session exists, its PromptTokens and CompletionTokens are included in the +// returned RunningSubagentInfo. +func TestAppWorkspace_RunningSubagents_TokenEnrichment(t *testing.T) { + t.Parallel() + + rt := subagents.NewRuntime() + t.Cleanup(rt.Shutdown) + + rt.Register("parent-1", "child-tok", "agent-tok", "green", "") + + sessions := &stubSessionService{ + sessions: map[string]session.Session{ + "child-tok": { + ID: "child-tok", + PromptTokens: 100, + CompletionTokens: 200, + }, + }, + } + + w := &AppWorkspace{ + app: &app.App{ + SubagentRuntime: rt, + Sessions: sessions, + }, + store: config.NewTestStore(&config.Config{}), + } + + got := w.RunningSubagents("parent-1") + require.Len(t, got, 1) + require.Equal(t, int64(100), got[0].PromptTokens) + require.Equal(t, int64(200), got[0].CompletionTokens) +} + +// TestAppWorkspace_SubscribeSubagentRuntime_NilRuntime verifies that a nil +// SubagentRuntime returns a closed channel without panicking. +func TestAppWorkspace_SubscribeSubagentRuntime_NilRuntime(t *testing.T) { + t.Parallel() + + w := &AppWorkspace{ + app: &app.App{SubagentRuntime: nil}, + store: config.NewTestStore(&config.Config{}), + } + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + ch := w.SubscribeSubagentRuntime(ctx) + require.NotNil(t, ch) + + select { + case _, ok := <-ch: + require.False(t, ok, "channel must be closed when SubagentRuntime is nil") + default: + t.Fatal("channel was not immediately closed for nil SubagentRuntime") + } +} + +// TestAppWorkspace_CancelSubagent_NilCoordinator verifies that calling +// CancelSubagent with a nil AgentCoordinator does not panic. +func TestAppWorkspace_CancelSubagent_NilCoordinator(t *testing.T) { + t.Parallel() + + w := &AppWorkspace{ + app: &app.App{AgentCoordinator: nil}, + store: config.NewTestStore(&config.Config{}), + } + + require.NotPanics(t, func() { + w.CancelSubagent("child-session-id") + }) +} + +// TestAppWorkspace_AllSubagents_NilManager verifies that a nil Subagents +// manager returns nil without panicking. +func TestAppWorkspace_AllSubagents_NilManager(t *testing.T) { + t.Parallel() + + w := &AppWorkspace{ + app: &app.App{Subagents: nil}, + store: config.NewTestStore(&config.Config{}), + } + + got := w.AllSubagents() + require.Nil(t, got) +} + +// TestAppWorkspace_AllSubagents_ScopeDetection verifies that the Scope field on +// returned SubagentDefInfo is set to "project" for agents whose file path is +// under the workspace working directory, "user" for agents outside, and +// "builtin" for agents with an empty FilePath. +func TestAppWorkspace_AllSubagents_ScopeDetection(t *testing.T) { + t.Parallel() + + workDir := t.TempDir() + + projectFile := filepath.Join(workDir, ".crush", "agents", "proj-agent.md") + require.NoError(t, os.MkdirAll(filepath.Dir(projectFile), 0o755)) + require.NoError(t, os.WriteFile( + projectFile, + []byte("---\nname: proj-agent\ndescription: Project agent.\n---\n\nBody.\n"), + 0o644, + )) + + userDir := t.TempDir() + userFile := filepath.Join(userDir, "user-agent.md") + require.NoError(t, os.WriteFile( + userFile, + []byte("---\nname: user-agent\ndescription: User agent.\n---\n\nBody.\n"), + 0o644, + )) + + projAgent := &subagents.Subagent{Name: "proj-agent", Description: "Project agent.", FilePath: projectFile} + userAgent := &subagents.Subagent{Name: "user-agent", Description: "User agent.", FilePath: userFile} + builtinAgent := &subagents.Subagent{Name: "builtin-agent", Description: "Built-in agent.", FilePath: ""} + + mgr := subagents.NewManager( + []*subagents.Subagent{projAgent, userAgent, builtinAgent}, + []*subagents.Subagent{projAgent, userAgent, builtinAgent}, + nil, + ) + t.Cleanup(mgr.Shutdown) + + w := &AppWorkspace{ + app: &app.App{Subagents: mgr}, + store: newStoreForWorkDir(workDir), + } + + got := w.AllSubagents() + require.Len(t, got, 3) + + byName := map[string]SubagentDefInfo{} + for _, info := range got { + byName[info.Name] = info + } + + require.Equal(t, "project", byName["proj-agent"].Scope) + require.Equal(t, "user", byName["user-agent"].Scope) + require.Equal(t, "builtin", byName["builtin-agent"].Scope) +} + +// TestAppWorkspace_DeleteUserSubagent_NotFound verifies that deleting a +// subagent by a name that doesn't exist returns an error. +func TestAppWorkspace_DeleteUserSubagent_NotFound(t *testing.T) { + t.Parallel() + + mgr := subagents.NewManager(nil, nil, nil) + t.Cleanup(mgr.Shutdown) + + w := &AppWorkspace{ + app: &app.App{Subagents: mgr}, + store: config.NewTestStore(&config.Config{}), + } + + err := w.DeleteUserSubagent("nonexistent-agent") + require.Error(t, err) +} + +// TestAppWorkspace_DeleteUserSubagent_NonUserScope verifies that deleting a +// project-scope subagent (file under workdir) returns an error. +func TestAppWorkspace_DeleteUserSubagent_NonUserScope(t *testing.T) { + t.Parallel() + + workDir := t.TempDir() + + projectFile := filepath.Join(workDir, ".crush", "agents", "proj-agent.md") + require.NoError(t, os.MkdirAll(filepath.Dir(projectFile), 0o755)) + require.NoError(t, os.WriteFile( + projectFile, + []byte("---\nname: proj-agent\ndescription: Project agent.\n---\n\nBody.\n"), + 0o644, + )) + + projAgent := &subagents.Subagent{Name: "proj-agent", Description: "Project agent.", FilePath: projectFile} + mgr := subagents.NewManager( + []*subagents.Subagent{projAgent}, + []*subagents.Subagent{projAgent}, + nil, + ) + t.Cleanup(mgr.Shutdown) + + w := &AppWorkspace{ + app: &app.App{Subagents: mgr}, + store: newStoreForWorkDir(workDir), + } + + err := w.DeleteUserSubagent("proj-agent") + require.Error(t, err) +} + +// TestAppWorkspace_DeleteUserSubagent_Success verifies that deleting a +// user-scope subagent removes the file from disk and the agent no longer +// appears in AllSubagents after the internal Manager is reloaded. +func TestAppWorkspace_DeleteUserSubagent_Success(t *testing.T) { + t.Parallel() + + workDir := t.TempDir() + userDir := t.TempDir() + + userFile := filepath.Join(userDir, "user-agent.md") + require.NoError(t, os.WriteFile( + userFile, + []byte("---\nname: user-agent\ndescription: User agent.\n---\n\nBody.\n"), + 0o644, + )) + + userAgent := &subagents.Subagent{Name: "user-agent", Description: "User agent.", FilePath: userFile} + mgr := subagents.NewManager( + []*subagents.Subagent{userAgent}, + []*subagents.Subagent{userAgent}, + nil, + ) + t.Cleanup(mgr.Shutdown) + + w := &AppWorkspace{ + app: &app.App{Subagents: mgr}, + store: newStoreForWorkDir(workDir), + } + + err := w.DeleteUserSubagent("user-agent") + require.NoError(t, err) + + // File must be gone from disk. + _, statErr := os.Stat(userFile) + require.True(t, os.IsNotExist(statErr), "file must have been deleted from disk") + + // Manager must no longer contain the deleted agent. + for _, info := range w.AllSubagents() { + require.NotEqual(t, "user-agent", info.Name, "deleted agent must not appear in AllSubagents after reload") + } +} + +// TestAppWorkspace_SessionTokens_Found verifies that SessionTokens returns the +// correct token counts for an existing session. +func TestAppWorkspace_SessionTokens_Found(t *testing.T) { + t.Parallel() + + sessions := &stubSessionService{ + sessions: map[string]session.Session{ + "sess-1": { + ID: "sess-1", + PromptTokens: 42, + CompletionTokens: 77, + }, + }, + } + + w := &AppWorkspace{ + app: &app.App{Sessions: sessions}, + store: config.NewTestStore(&config.Config{}), + } + + prompt, completion, err := w.SessionTokens(context.Background(), "sess-1") + require.NoError(t, err) + require.Equal(t, int64(42), prompt) + require.Equal(t, int64(77), completion) +} + +// TestAppWorkspace_SessionTokens_NotFound verifies that SessionTokens returns +// an error when the session does not exist. +func TestAppWorkspace_SessionTokens_NotFound(t *testing.T) { + t.Parallel() + + sessions := &stubSessionService{ + sessions: map[string]session.Session{}, + } + + w := &AppWorkspace{ + app: &app.App{Sessions: sessions}, + store: config.NewTestStore(&config.Config{}), + } + + _, _, err := w.SessionTokens(context.Background(), "does-not-exist") + require.Error(t, err) +} + +// TestAddOrRemove covers the pure list helper backing SetSubagentDisabled. +func TestAddOrRemove(t *testing.T) { + t.Parallel() + + require.Equal(t, []string{"a"}, addOrRemove(nil, "a", true), "add to empty") + require.Equal(t, []string{"a"}, addOrRemove([]string{"a"}, "a", true), "add dedups") + require.Equal(t, []string{}, addOrRemove([]string{"a"}, "a", false), "remove last") + require.Equal(t, []string{"b", "c"}, addOrRemove([]string{"b", "a", "c", "a"}, "a", false), + "remove drops all occurrences, keeps others") + require.Equal(t, []string{"b", "a"}, addOrRemove([]string{"b"}, "a", true), "add appends") +} + +// TestAppWorkspace_DeleteUserSubagent_ReloadValidatesModel verifies that the +// reload after a delete validates model ids (passes cfg.IsKnownModel, not +// nil). A subagent referencing an unknown model must NOT become active after +// the reload — with a nil validator it would be wrongly accepted. +func TestAppWorkspace_DeleteUserSubagent_ReloadValidatesModel(t *testing.T) { + t.Parallel() + + workDir := t.TempDir() + userDir := t.TempDir() + + // One valid user subagent to delete, and one with an unknown model id. + keepFile := filepath.Join(userDir, "keep-agent.md") + require.NoError(t, os.WriteFile( + keepFile, + []byte("---\nname: keep-agent\ndescription: Keep.\n---\n\nBody.\n"), + 0o644, + )) + badFile := filepath.Join(userDir, "bad-agent.md") + require.NoError(t, os.WriteFile( + badFile, + []byte("---\nname: bad-agent\ndescription: Bad model.\nmodel: not-a-real-model-id\n---\n\nBody.\n"), + 0o644, + )) + + keepAgent := &subagents.Subagent{Name: "keep-agent", Description: "Keep.", FilePath: keepFile} + mgr := subagents.NewManager( + []*subagents.Subagent{keepAgent}, + []*subagents.Subagent{keepAgent}, + nil, + ) + t.Cleanup(mgr.Shutdown) + + // Empty (but non-nil) providers => IsKnownModelID returns false for any + // specific id, so bad-agent must be rejected on reload. SubagentsPaths + // drives rediscovery. + cfg := &config.Config{ + Options: &config.Options{SubagentsPaths: []string{userDir}}, + Providers: csync.NewMap[string, config.ProviderConfig](), + } + w := &AppWorkspace{ + app: &app.App{Subagents: mgr}, + store: config.NewTestStoreWithWorkingDir(cfg, workDir), + } + + require.NoError(t, w.DeleteUserSubagent("keep-agent")) + + for _, info := range w.AllSubagents() { + require.NotEqual(t, "bad-agent", info.Name, + "subagent with an unknown model must stay rejected after reload") + } +} diff --git a/internal/workspace/workspace.go b/internal/workspace/workspace.go index 9049b7bc68..32cdc5a73d 100644 --- a/internal/workspace/workspace.go +++ b/internal/workspace/workspace.go @@ -17,8 +17,10 @@ import ( "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/skills" + "github.com/charmbracelet/crush/internal/subagents" ) // LSPClientInfo holds information about an LSP client's state. This is @@ -144,6 +146,14 @@ type Workspace interface { InitializePrompt() (string, error) ListSkills(ctx context.Context) ([]skills.CatalogEntry, error) ReadSkill(ctx context.Context, skillID string) ([]byte, skills.SkillReadResult, error) + ActiveSubagents() []SubagentInfo + RunningSubagents(parentSessionID string) []RunningSubagentInfo + SubscribeSubagentRuntime(ctx context.Context) <-chan pubsub.Event[subagents.RuntimeEvent] + CancelSubagent(childSessionID string) + AllSubagents() []SubagentDefInfo + DeleteUserSubagent(name string) error + SetSubagentDisabled(name string, disabled bool) error + SessionTokens(ctx context.Context, sessionID string) (prompt, completion int64, err error) // MCP operations (server-side in client mode) MCPGetStates() map[string]mcptools.ClientInfo @@ -160,6 +170,37 @@ type Workspace interface { Shutdown() } +// SubagentInfo holds the minimal frontend-facing data for an active subagent. +type SubagentInfo struct { + Name string + Description string +} + +// RunningSubagentInfo holds frontend-facing data for a currently running +// subagent instance, enriched with session token counts. +type RunningSubagentInfo struct { + ChildSessionID string + ParentSessionID string + Name string + Color string + Model string + Status string + StartedAt time.Time + PromptTokens int64 + CompletionTokens int64 +} + +// SubagentDefInfo holds frontend-facing data for a discovered subagent +// definition, including its scope relative to the workspace. +type SubagentDefInfo struct { + Name string + Description string + Color string + FilePath string + Scope string // "user", "project", or "builtin" + Disabled bool +} + // MCPResourceContents holds the contents of an MCP resource. type MCPResourceContents struct { URI string `json:"uri"`