Skip to content
Open
5 changes: 5 additions & 0 deletions internal/agent/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/charmbracelet/crush/internal/oauth/copilot"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/questions"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/skills"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -82,6 +83,7 @@ type coordinator struct {
messages message.Service
permissions permission.Service
history history.Service
questions questions.Service
filetracker filetracker.Service
lspManager *lsp.Manager
notify pubsub.Publisher[notify.Notification]
Expand All @@ -104,6 +106,7 @@ func NewCoordinator(
messages message.Service,
permissions permission.Service,
history history.Service,
questions questions.Service,
filetracker filetracker.Service,
lspManager *lsp.Manager,
notify pubsub.Publisher[notify.Notification],
Expand All @@ -118,6 +121,7 @@ func NewCoordinator(
messages: messages,
permissions: permissions,
history: history,
questions: questions,
filetracker: filetracker,
lspManager: lspManager,
notify: notify,
Expand Down Expand Up @@ -470,6 +474,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
tools.NewJobKillTool(),
tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
tools.NewAskUserQuestionsTool(c.questions),
tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
tools.NewGlobTool(c.cfg.WorkingDir()),
Expand Down
43 changes: 43 additions & 0 deletions internal/agent/tools/ask_user_questions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package tools

import (
"context"
_ "embed"
"encoding/json"

"charm.land/fantasy"
"github.com/charmbracelet/crush/internal/questions"
)

const askUserQuestionsToolName = "ask_user_questions"

//go:embed ask_user_questions.md
var askUserQuestionDescription []byte

// AskUserQuestionParams holds the parameters for the ask_user_questions tool.
type AskUserQuestionParams struct {
Questions []questions.Question `json:"questions" description:"Array of questions to ask the user"`
}

// NewAskUserQuestionsTool creates a tool that asks the user multiple-choice questions.
func NewAskUserQuestionsTool(questionsService questions.Service) fantasy.AgentTool {
return fantasy.NewAgentTool(
askUserQuestionsToolName,
string(askUserQuestionDescription),
func(ctx context.Context, params AskUserQuestionParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
sessionID := GetSessionFromContext(ctx)
req := questions.NewQuestionsRequest(sessionID, call.ID, params.Questions)

resp, err := questionsService.Ask(ctx, req)
if err != nil {
return fantasy.ToolResponse{}, err
}

jsonResp, err := json.Marshal(resp.Answers)
if err != nil {
return fantasy.ToolResponse{}, err
}
return fantasy.NewTextResponse(string(jsonResp)), nil
},
)
}
48 changes: 48 additions & 0 deletions internal/agent/tools/ask_user_questions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
Ask user questions to gather information, clarify ambiguity, taking decisions,
determining user preferences and/or taste. When in doubt, use this tool to ask the user.

<usage>
- Provide array of questions to ask user
- Each question needs:
- UUID to correlate originating question and answer
- short question text to ask user
- array of options from which user will select one or more answers
- boolean indicating if user can answer selecting multiple options
- Each option needs:
- short label identifying the option
- optional description of the option
</usage>

<when_to_use>
- Ask user questions during execution
- Gather user preferences or requirements
- Clarifying ambiguous instructions
- Get decisions on implementation choices as you work
- Offer choices to the user about what direction to take
- Determine user taste
</when_to_use>

<tips>
- If you recommend specific options, put them first in array of options and append "(Recommended)" to label
- To let user pick something else, offer "None of the above" option
- When planning
- use this tool to clarify requirements or choose between approaches BEFORE finalizing your plan.
- IMPORTANT: do not ask for feedback about the plan while you are working on it, because user cannot see it until you send it back
</tips>

<features>
- More questions can be asked to user at once
- User can be allowed to select multiple options when answering
</features>

<limitations>
- Keep question and option labels short
- If option's description is too long, it will be truncated
</limitations>

<result>
- Tool will report back array of answers from user
- Each answer will include:
- UUID of originating question
- array of options' labels selected by user
</result>
179 changes: 179 additions & 0 deletions internal/agent/tools/ask_user_questions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package tools

import (
"context"
"encoding/json"
"errors"
"testing"

"charm.land/fantasy"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/questions"
"github.com/stretchr/testify/require"
)

type mockQuestionsService struct {
*pubsub.Broker[questions.QuestionsRequest]
response questions.QuestionsResponse
err error
}

func (m *mockQuestionsService) Ask(_ context.Context, _ questions.QuestionsRequest) (questions.QuestionsResponse, error) {
if m.err != nil {
return questions.QuestionsResponse{}, m.err
}
return m.response, nil
}

func (m *mockQuestionsService) Answer(_ questions.QuestionsResponse) {
}

func TestAskUserQuestionsTool(t *testing.T) {
t.Parallel()

tests := []struct {
name string
questions []questions.Question
mockResponse questions.QuestionsResponse
mockErr error
expectedError string
}{
{
name: "single question with single answer",
questions: []questions.Question{
{
ID: "q1",
Question: "Do you like Go?",
Options: []questions.Option{
{Label: "Yes"},
{Label: "No"},
},
},
},
mockResponse: questions.QuestionsResponse{
RequestID: "req-1",
Answers: []questions.Answer{
{
ID: "q1",
Selected: []string{"Yes"},
},
},
},
},
{
name: "single question with multiple answers",
questions: []questions.Question{
{
ID: "q2",
Question: "Which languages do you know?",
MultiSelect: true,
Options: []questions.Option{
{Label: "Go"},
{Label: "Python"},
{Label: "Rust"},
},
},
},
mockResponse: questions.QuestionsResponse{
RequestID: "req-2",
Answers: []questions.Answer{
{
ID: "q2",
Selected: []string{"Go", "Rust"},
},
},
},
},
{
name: "multiple questions with mixed answers",
questions: []questions.Question{
{
ID: "q3",
Question: "Favorite editor?",
Options: []questions.Option{
{Label: "Vim"},
{Label: "Emacs"},
{Label: "VSCode"},
},
},
{
ID: "q4",
Question: "Preferred OS?",
MultiSelect: true,
Options: []questions.Option{
{Label: "Linux"},
{Label: "macOS"},
{Label: "Windows"},
},
},
},
mockResponse: questions.QuestionsResponse{
RequestID: "req-3",
Answers: []questions.Answer{
{
ID: "q3",
Selected: []string{"Vim"},
},
{
ID: "q4",
Selected: []string{"Linux", "macOS"},
},
},
},
},
{
name: "service returns error",
questions: []questions.Question{
{
ID: "q5",
Question: "Trigger error?",
Options: []questions.Option{{Label: "Yes"}},
},
},
mockErr: errors.New("service failure"),
expectedError: "service failure",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

mockService := &mockQuestionsService{
Broker: pubsub.NewBroker[questions.QuestionsRequest](),
response: tt.mockResponse,
err: tt.mockErr,
}

tool := NewAskUserQuestionsTool(mockService)
ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")

params := AskUserQuestionParams{
Questions: tt.questions,
}
input, err := json.Marshal(params)
require.NoError(t, err)

call := fantasy.ToolCall{
ID: "test-call",
Name: askUserQuestionsToolName,
Input: string(input),
}

resp, err := tool.Run(ctx, call)

if tt.expectedError != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.expectedError)
return
}

require.NoError(t, err)
require.False(t, resp.IsError)

var gotAnswers []questions.Answer
require.NoError(t, json.Unmarshal([]byte(resp.Content), &gotAnswers))
require.Equal(t, tt.mockResponse.Answers, gotAnswers)
})
}
}
5 changes: 5 additions & 0 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/questions"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/shell"
"github.com/charmbracelet/crush/internal/ui/anim"
Expand All @@ -55,6 +56,7 @@ type App struct {
Messages message.Service
History history.Service
Permissions permission.Service
Questions questions.Service
FileTracker filetracker.Service

AgentCoordinator agent.Coordinator
Expand Down Expand Up @@ -92,6 +94,7 @@ func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore) (*App, er
Messages: messages,
History: files,
Permissions: permission.NewPermissionService(store.WorkingDir(), skipPermissionsRequests, allowedTools),
Questions: questions.NewService(),
FileTracker: filetracker.NewService(q),
LSPManager: lsp.NewManager(store),

Expand Down Expand Up @@ -477,6 +480,7 @@ func (app *App) setupEvents() {
setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "questions", app.Questions.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "agent-notifications", app.agentNotifications.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events)
Expand Down Expand Up @@ -549,6 +553,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
app.Messages,
app.Permissions,
app.History,
app.Questions,
app.FileTracker,
app.LSPManager,
app.agentNotifications,
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ const maxRecentModelsPerType = 5
func allToolNames() []string {
return []string{
"agent",
"ask_user_questions",
"bash",
"crush_info",
"crush_logs",
Expand Down
4 changes: 2 additions & 2 deletions internal/config/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
coderAgent, ok := cfg.Agents[AgentCoder]
require.True(t, ok)

assert.Equal(t, []string{"agent", "bash", "crush_info", "crush_logs", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "todos", "view", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
assert.Equal(t, []string{"agent", "ask_user_questions", "bash", "crush_info", "crush_logs", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "todos", "view", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)

taskAgent, ok := cfg.Agents[AgentTask]
require.True(t, ok)
Expand All @@ -513,7 +513,7 @@ func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
cfg.SetupAgents()
coderAgent, ok := cfg.Agents[AgentCoder]
require.True(t, ok)
assert.Equal(t, []string{"agent", "bash", "crush_info", "crush_logs", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "todos", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
assert.Equal(t, []string{"agent", "ask_user_questions", "bash", "crush_info", "crush_logs", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "todos", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)

taskAgent, ok := cfg.Agents[AgentTask]
require.True(t, ok)
Expand Down
Loading
Loading