diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go
index 1130438a61..f6ec42e8b1 100644
--- a/internal/agent/coordinator.go
+++ b/internal/agent/coordinator.go
@@ -32,6 +32,7 @@ import (
"github.com/charmbracelet/crush/internal/oauth/copilot"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/charmbracelet/crush/internal/questions"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/skills"
"golang.org/x/sync/errgroup"
@@ -82,6 +83,7 @@ type coordinator struct {
messages message.Service
permissions permission.Service
history history.Service
+ questions questions.Service
filetracker filetracker.Service
lspManager *lsp.Manager
notify pubsub.Publisher[notify.Notification]
@@ -104,6 +106,7 @@ func NewCoordinator(
messages message.Service,
permissions permission.Service,
history history.Service,
+ questions questions.Service,
filetracker filetracker.Service,
lspManager *lsp.Manager,
notify pubsub.Publisher[notify.Notification],
@@ -118,6 +121,7 @@ func NewCoordinator(
messages: messages,
permissions: permissions,
history: history,
+ questions: questions,
filetracker: filetracker,
lspManager: lspManager,
notify: notify,
@@ -470,6 +474,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
tools.NewJobKillTool(),
tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
+ tools.NewAskUserQuestionsTool(c.questions),
tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
tools.NewGlobTool(c.cfg.WorkingDir()),
diff --git a/internal/agent/tools/ask_user_questions.go b/internal/agent/tools/ask_user_questions.go
new file mode 100644
index 0000000000..02b1507e9a
--- /dev/null
+++ b/internal/agent/tools/ask_user_questions.go
@@ -0,0 +1,43 @@
+package tools
+
+import (
+ "context"
+ _ "embed"
+ "encoding/json"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/questions"
+)
+
+const askUserQuestionsToolName = "ask_user_questions"
+
+//go:embed ask_user_questions.md
+var askUserQuestionDescription []byte
+
+// AskUserQuestionParams holds the parameters for the ask_user_questions tool.
+type AskUserQuestionParams struct {
+ Questions []questions.Question `json:"questions" description:"Array of questions to ask the user"`
+}
+
+// NewAskUserQuestionsTool creates a tool that asks the user multiple-choice questions.
+func NewAskUserQuestionsTool(questionsService questions.Service) fantasy.AgentTool {
+ return fantasy.NewAgentTool(
+ askUserQuestionsToolName,
+ string(askUserQuestionDescription),
+ func(ctx context.Context, params AskUserQuestionParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
+ sessionID := GetSessionFromContext(ctx)
+ req := questions.NewQuestionsRequest(sessionID, call.ID, params.Questions)
+
+ resp, err := questionsService.Ask(ctx, req)
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
+
+ jsonResp, err := json.Marshal(resp.Answers)
+ if err != nil {
+ return fantasy.ToolResponse{}, err
+ }
+ return fantasy.NewTextResponse(string(jsonResp)), nil
+ },
+ )
+}
diff --git a/internal/agent/tools/ask_user_questions.md b/internal/agent/tools/ask_user_questions.md
new file mode 100644
index 0000000000..28221fab3f
--- /dev/null
+++ b/internal/agent/tools/ask_user_questions.md
@@ -0,0 +1,48 @@
+Ask user questions to gather information, clarify ambiguity, taking decisions,
+determining user preferences and/or taste. When in doubt, use this tool to ask the user.
+
+
+- Provide array of questions to ask user
+- Each question needs:
+ - UUID to correlate originating question and answer
+ - short question text to ask user
+ - array of options from which user will select one or more answers
+ - boolean indicating if user can answer selecting multiple options
+- Each option needs:
+ - short label identifying the option
+ - optional description of the option
+
+
+
+- Ask user questions during execution
+- Gather user preferences or requirements
+- Clarifying ambiguous instructions
+- Get decisions on implementation choices as you work
+- Offer choices to the user about what direction to take
+- Determine user taste
+
+
+
+- If you recommend specific options, put them first in array of options and append "(Recommended)" to label
+- To let user pick something else, offer "None of the above" option
+- When planning
+ - use this tool to clarify requirements or choose between approaches BEFORE finalizing your plan.
+ - IMPORTANT: do not ask for feedback about the plan while you are working on it, because user cannot see it until you send it back
+
+
+
+- More questions can be asked to user at once
+- User can be allowed to select multiple options when answering
+
+
+
+- Keep question and option labels short
+- If option's description is too long, it will be truncated
+
+
+
+- Tool will report back array of answers from user
+- Each answer will include:
+ - UUID of originating question
+ - array of options' labels selected by user
+
diff --git a/internal/agent/tools/ask_user_questions_test.go b/internal/agent/tools/ask_user_questions_test.go
new file mode 100644
index 0000000000..a6050dc4ea
--- /dev/null
+++ b/internal/agent/tools/ask_user_questions_test.go
@@ -0,0 +1,179 @@
+package tools
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "testing"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/charmbracelet/crush/internal/questions"
+ "github.com/stretchr/testify/require"
+)
+
+type mockQuestionsService struct {
+ *pubsub.Broker[questions.QuestionsRequest]
+ response questions.QuestionsResponse
+ err error
+}
+
+func (m *mockQuestionsService) Ask(_ context.Context, _ questions.QuestionsRequest) (questions.QuestionsResponse, error) {
+ if m.err != nil {
+ return questions.QuestionsResponse{}, m.err
+ }
+ return m.response, nil
+}
+
+func (m *mockQuestionsService) Answer(_ questions.QuestionsResponse) {
+}
+
+func TestAskUserQuestionsTool(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ questions []questions.Question
+ mockResponse questions.QuestionsResponse
+ mockErr error
+ expectedError string
+ }{
+ {
+ name: "single question with single answer",
+ questions: []questions.Question{
+ {
+ ID: "q1",
+ Question: "Do you like Go?",
+ Options: []questions.Option{
+ {Label: "Yes"},
+ {Label: "No"},
+ },
+ },
+ },
+ mockResponse: questions.QuestionsResponse{
+ RequestID: "req-1",
+ Answers: []questions.Answer{
+ {
+ ID: "q1",
+ Selected: []string{"Yes"},
+ },
+ },
+ },
+ },
+ {
+ name: "single question with multiple answers",
+ questions: []questions.Question{
+ {
+ ID: "q2",
+ Question: "Which languages do you know?",
+ MultiSelect: true,
+ Options: []questions.Option{
+ {Label: "Go"},
+ {Label: "Python"},
+ {Label: "Rust"},
+ },
+ },
+ },
+ mockResponse: questions.QuestionsResponse{
+ RequestID: "req-2",
+ Answers: []questions.Answer{
+ {
+ ID: "q2",
+ Selected: []string{"Go", "Rust"},
+ },
+ },
+ },
+ },
+ {
+ name: "multiple questions with mixed answers",
+ questions: []questions.Question{
+ {
+ ID: "q3",
+ Question: "Favorite editor?",
+ Options: []questions.Option{
+ {Label: "Vim"},
+ {Label: "Emacs"},
+ {Label: "VSCode"},
+ },
+ },
+ {
+ ID: "q4",
+ Question: "Preferred OS?",
+ MultiSelect: true,
+ Options: []questions.Option{
+ {Label: "Linux"},
+ {Label: "macOS"},
+ {Label: "Windows"},
+ },
+ },
+ },
+ mockResponse: questions.QuestionsResponse{
+ RequestID: "req-3",
+ Answers: []questions.Answer{
+ {
+ ID: "q3",
+ Selected: []string{"Vim"},
+ },
+ {
+ ID: "q4",
+ Selected: []string{"Linux", "macOS"},
+ },
+ },
+ },
+ },
+ {
+ name: "service returns error",
+ questions: []questions.Question{
+ {
+ ID: "q5",
+ Question: "Trigger error?",
+ Options: []questions.Option{{Label: "Yes"}},
+ },
+ },
+ mockErr: errors.New("service failure"),
+ expectedError: "service failure",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mockService := &mockQuestionsService{
+ Broker: pubsub.NewBroker[questions.QuestionsRequest](),
+ response: tt.mockResponse,
+ err: tt.mockErr,
+ }
+
+ tool := NewAskUserQuestionsTool(mockService)
+ ctx := context.WithValue(context.Background(), SessionIDContextKey, "test-session")
+
+ params := AskUserQuestionParams{
+ Questions: tt.questions,
+ }
+ input, err := json.Marshal(params)
+ require.NoError(t, err)
+
+ call := fantasy.ToolCall{
+ ID: "test-call",
+ Name: askUserQuestionsToolName,
+ Input: string(input),
+ }
+
+ resp, err := tool.Run(ctx, call)
+
+ if tt.expectedError != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.expectedError)
+ return
+ }
+
+ require.NoError(t, err)
+ require.False(t, resp.IsError)
+
+ var gotAnswers []questions.Answer
+ require.NoError(t, json.Unmarshal([]byte(resp.Content), &gotAnswers))
+ require.Equal(t, tt.mockResponse.Answers, gotAnswers)
+ })
+ }
+}
diff --git a/internal/app/app.go b/internal/app/app.go
index 987d0ec060..cfaf816dad 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -32,6 +32,7 @@ import (
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/charmbracelet/crush/internal/questions"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/shell"
"github.com/charmbracelet/crush/internal/ui/anim"
@@ -55,6 +56,7 @@ type App struct {
Messages message.Service
History history.Service
Permissions permission.Service
+ Questions questions.Service
FileTracker filetracker.Service
AgentCoordinator agent.Coordinator
@@ -92,6 +94,7 @@ func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore) (*App, er
Messages: messages,
History: files,
Permissions: permission.NewPermissionService(store.WorkingDir(), skipPermissionsRequests, allowedTools),
+ Questions: questions.NewService(),
FileTracker: filetracker.NewService(q),
LSPManager: lsp.NewManager(store),
@@ -477,6 +480,7 @@ func (app *App) setupEvents() {
setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
+ setupSubscriber(ctx, app.serviceEventsWG, "questions", app.Questions.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "agent-notifications", app.agentNotifications.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events)
@@ -549,6 +553,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
app.Messages,
app.Permissions,
app.History,
+ app.Questions,
app.FileTracker,
app.LSPManager,
app.agentNotifications,
diff --git a/internal/config/config.go b/internal/config/config.go
index cee8ab8c49..0fd157250b 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -461,6 +461,7 @@ const maxRecentModelsPerType = 5
func allToolNames() []string {
return []string{
"agent",
+ "ask_user_questions",
"bash",
"crush_info",
"crush_logs",
diff --git a/internal/config/load_test.go b/internal/config/load_test.go
index 68d52da39f..123dd61594 100644
--- a/internal/config/load_test.go
+++ b/internal/config/load_test.go
@@ -490,7 +490,7 @@ func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
coderAgent, ok := cfg.Agents[AgentCoder]
require.True(t, ok)
- assert.Equal(t, []string{"agent", "bash", "crush_info", "crush_logs", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "todos", "view", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
+ assert.Equal(t, []string{"agent", "ask_user_questions", "bash", "crush_info", "crush_logs", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "todos", "view", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
taskAgent, ok := cfg.Agents[AgentTask]
require.True(t, ok)
@@ -513,7 +513,7 @@ func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
cfg.SetupAgents()
coderAgent, ok := cfg.Agents[AgentCoder]
require.True(t, ok)
- assert.Equal(t, []string{"agent", "bash", "crush_info", "crush_logs", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "todos", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
+ assert.Equal(t, []string{"agent", "ask_user_questions", "bash", "crush_info", "crush_logs", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "todos", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
taskAgent, ok := cfg.Agents[AgentTask]
require.True(t, ok)
diff --git a/internal/questions/service.go b/internal/questions/service.go
new file mode 100644
index 0000000000..fc339630ed
--- /dev/null
+++ b/internal/questions/service.go
@@ -0,0 +1,154 @@
+package questions
+
+import (
+ "context"
+ "log/slog"
+
+ "github.com/charmbracelet/crush/internal/csync"
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/google/uuid"
+)
+
+// Option represents a single answer option for a question.
+type Option struct {
+ Label string `json:"label" description:"Short label identifying the option"`
+ Description string `json:"description" description:"Optional description of the option"`
+}
+
+// Question represents a single question to be asked.
+type Question struct {
+ ID string `json:"id" description:"UUID of originating question: used to correlate question and answer"`
+ Question string `json:"question" description:"Short question to ask the user"`
+
+ // Options are the possible Answer(s) to the question.
+ Options []Option `json:"options" description:"Array of options from which user will select one or more answers"`
+
+ // MultiSelect indicates whether the user can Answer by selecting multiple Option(s).
+ MultiSelect bool `json:"multi_select" description:"Indicates if user can answer selecting multiple options"`
+}
+
+// QuestionsRequest represents a request to ask a set of Question(s).
+type QuestionsRequest struct {
+ ID string
+ SessionID string
+ ToolCallID string
+ Questions []Question
+}
+
+// NewQuestionsRequest creates a new QuestionsRequest.
+func NewQuestionsRequest(sessionID string, toolCallID string, questions []Question) QuestionsRequest {
+ return QuestionsRequest{
+ ID: uuid.New().String(),
+ SessionID: sessionID,
+ ToolCallID: toolCallID,
+ Questions: questions,
+ }
+}
+
+// Answer represents a user's response to a Question.
+type Answer struct {
+ // ID is the ID of the original Question.
+ // This ensures that even if the LLM asks multiple Question(s),
+ // the Answer(s) can be easily correlated to the original Question.
+ ID string `json:"id" description:"UUID of originating question"`
+
+ // Selected is the index of the selected Option.
+ Selected []string `json:"selected" description:"Array of options' labels selected by user"`
+}
+
+func NewAnswer(question Question) Answer {
+ return Answer{
+ ID: question.ID,
+ Selected: []string{},
+ }
+}
+
+// Select set one or more user-selection Option(s) to the Answer.
+func (ans *Answer) Select(label ...string) {
+ ans.Selected = append(ans.Selected, label...)
+}
+
+// QuestionsResponse represents the set of Answer(s) to a set of Question(s)
+// contained in the originating QuestionsRequest.
+// The originating QuestionsRequest is identified by the RequestID.
+type QuestionsResponse struct {
+ RequestID string `json:"request_id"`
+ Answers []Answer `json:"answers"`
+}
+
+func NewQuestionsResponse(req *QuestionsRequest) QuestionsResponse {
+ return QuestionsResponse{
+ RequestID: req.ID,
+ Answers: make([]Answer, len(req.Questions)),
+ }
+}
+
+// SetAnswerAt sets the Answer at the given index.
+func (res *QuestionsResponse) SetAnswerAt(idx int, ans Answer) {
+ res.Answers[idx] = ans
+}
+
+// IsComplete returns true if the QuestionsResponse contains all the expected Answer(s).
+func (res *QuestionsResponse) IsComplete() bool {
+ return len(res.Answers) == cap(res.Answers)
+}
+
+// Service is the interface for the questionsService.
+// When Ask is invoked, a new QuestionsRequest is published to the service.
+// When the user answers the Question(s), the Answer(s) are sent back via the Answer method.
+type Service interface {
+ pubsub.Subscriber[QuestionsRequest]
+
+ Ask(ctx context.Context, req QuestionsRequest) (QuestionsResponse, error)
+ Answer(response QuestionsResponse)
+}
+
+// questionsService is a pubsub.Broker[QuestionsRequest] that tracks the pending
+// QuestionsResponse channels for each submitted QuestionsRequest.
+type questionsService struct {
+ *pubsub.Broker[QuestionsRequest]
+
+ // pendingRequests maps a QuestionsRequest.ID to a channel that will be used to
+ // send the QuestionsResponse back when the user answers the Question(s).
+ pendingRequests *csync.Map[string, chan QuestionsResponse]
+}
+
+// NewService creates a new questionsService.
+func NewService() Service {
+ return &questionsService{
+ Broker: pubsub.NewBroker[QuestionsRequest](),
+ pendingRequests: csync.NewMap[string, chan QuestionsResponse](),
+ }
+}
+
+func (s *questionsService) Ask(ctx context.Context, req QuestionsRequest) (QuestionsResponse, error) {
+ slog.Debug("Ask", "request_id", req.ID, "session_id", req.SessionID, "tool_call_id", req.ToolCallID, "questions", len(req.Questions))
+ ch := make(chan QuestionsResponse, 1)
+ s.pendingRequests.Set(req.ID, ch)
+ defer s.pendingRequests.Del(req.ID)
+
+ s.Publish(pubsub.CreatedEvent, req)
+
+ select {
+ // If the context is cancelled, return and empty AnswersResponse and the error.
+ case <-ctx.Done():
+ slog.Debug("Ask cancelled", "request_id", req.ID)
+ return QuestionsResponse{RequestID: req.ID}, ctx.Err()
+ // Otherwise, wait for the user to answer the Question(s).
+ case resp := <-ch:
+ return resp, nil
+ }
+}
+
+func (s *questionsService) Answer(res QuestionsResponse) {
+ if !res.IsComplete() {
+ slog.Warn("Incomplete response - missing answers", "want", cap(res.Answers), "got", len(res.Answers))
+ }
+
+ if ch, found := s.pendingRequests.Get(res.RequestID); found {
+ slog.Debug("Reporting answers from user", "request_id", res.RequestID, "answers", res.Answers)
+ ch <- res
+ } else {
+ slog.Warn("Received answers for unknown questions", "request_id", res.RequestID, "answers", res.Answers)
+ }
+}
diff --git a/internal/questions/service_test.go b/internal/questions/service_test.go
new file mode 100644
index 0000000000..1d2676aab7
--- /dev/null
+++ b/internal/questions/service_test.go
@@ -0,0 +1,212 @@
+package questions
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestService_AskAndAnswer(t *testing.T) {
+ tests := []struct {
+ name string
+ setupContext func() (context.Context, context.CancelFunc)
+ questionsReq QuestionsRequest
+ handleAnswering func(srv Service, req QuestionsRequest)
+ expErr error
+ expAnswersRes QuestionsResponse
+ }{
+ {
+ name: "successful single answer",
+ setupContext: func() (context.Context, context.CancelFunc) {
+ return context.WithTimeout(context.Background(), 2*time.Second)
+ },
+ questionsReq: NewQuestionsRequest("sess1", "tool1", []Question{
+ {
+ ID: "q1",
+ Question: "Test?",
+ Options: []Option{{Label: "opt1"}},
+ MultiSelect: false,
+ },
+ }),
+ handleAnswering: func(srv Service, req QuestionsRequest) {
+ srv.Answer(QuestionsResponse{
+ RequestID: req.ID,
+ Answers: []Answer{{ID: "q1", Selected: []string{"opt1"}}},
+ })
+ },
+ expAnswersRes: QuestionsResponse{
+ Answers: []Answer{{ID: "q1", Selected: []string{"opt1"}}},
+ },
+ },
+ {
+ name: "successful multiple options, single select",
+ setupContext: func() (context.Context, context.CancelFunc) {
+ return context.WithTimeout(context.Background(), 2*time.Second)
+ },
+ questionsReq: NewQuestionsRequest("sess2", "tool2", []Question{
+ {
+ ID: "q2",
+ Question: "Choose one:",
+ Options: []Option{
+ {Label: "A"},
+ {Label: "B"},
+ {Label: "C"},
+ },
+ MultiSelect: false,
+ },
+ }),
+ handleAnswering: func(srv Service, req QuestionsRequest) {
+ srv.Answer(QuestionsResponse{
+ RequestID: req.ID,
+ Answers: []Answer{{ID: "q2", Selected: []string{"B"}}},
+ })
+ },
+ expAnswersRes: QuestionsResponse{
+ Answers: []Answer{{ID: "q2", Selected: []string{"B"}}},
+ },
+ },
+ {
+ name: "successful multiple options, multi select",
+ setupContext: func() (context.Context, context.CancelFunc) {
+ return context.WithTimeout(context.Background(), 2*time.Second)
+ },
+ questionsReq: NewQuestionsRequest("sess3", "tool3", []Question{
+ {
+ ID: "q3",
+ Question: "Choose many:",
+ Options: []Option{
+ {Label: "Apple"},
+ {Label: "Banana"},
+ {Label: "Cherry"},
+ },
+ MultiSelect: true,
+ },
+ }),
+ handleAnswering: func(srv Service, req QuestionsRequest) {
+ srv.Answer(QuestionsResponse{
+ RequestID: req.ID,
+ Answers: []Answer{{ID: "q3", Selected: []string{"Apple", "Cherry"}}},
+ })
+ },
+ expAnswersRes: QuestionsResponse{
+ Answers: []Answer{{ID: "q3", Selected: []string{"Apple", "Cherry"}}},
+ },
+ },
+ {
+ name: "context canceled",
+ setupContext: func() (context.Context, context.CancelFunc) {
+ // Cancel immediately
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ return ctx, cancel
+ },
+ questionsReq: NewQuestionsRequest("sess4", "tool4", []Question{
+ {ID: "q4", Question: "Cancel me?"},
+ }),
+ handleAnswering: func(srv Service, req QuestionsRequest) {
+ // Do nothing, let context cancellation abort Ask
+ },
+ expErr: context.Canceled,
+ // When cancelled, it sets RequestID to the requested ID
+ expAnswersRes: QuestionsResponse{},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ srv := NewService()
+ ctx, cancel := tt.setupContext()
+ defer cancel()
+
+ // Set the expected answers RequestID to the questions RequestID:
+ // the latter is a dynamically generated UUID, so we inject it
+ // into the expected answers response.
+ tt.expAnswersRes.RequestID = tt.questionsReq.ID
+
+ // Subscribe to wait for the Ask event so we know it's safe to Answer
+ subCtx, subCancel := context.WithCancel(context.Background())
+ defer subCancel()
+ sub := srv.Subscribe(subCtx)
+
+ // Answer the question in a separate goroutine
+ go func() {
+ select {
+ case <-sub:
+ tt.handleAnswering(srv, tt.questionsReq)
+ case <-time.After(1 * time.Second):
+ // Fallback to prevent hanging if Publish fails
+ }
+ }()
+
+ // Ask the questions and check the answers
+ resp, err := srv.Ask(ctx, tt.questionsReq)
+ if tt.expErr != nil {
+ require.ErrorIs(t, err, tt.expErr)
+ } else {
+ require.NoError(t, err)
+ }
+ require.Equal(t, tt.expAnswersRes, resp)
+
+ // Verify cleanup
+ s := srv.(*questionsService)
+ require.Equal(t, 0, s.pendingRequests.Len(), "pendingRequests should be empty after Ask returns")
+ })
+ }
+}
+
+func TestService_OrphanedAnswer(t *testing.T) {
+ srv := NewService()
+ // Answering an ID that doesn't exist should not panic or block
+ require.NotPanics(t, func() {
+ srv.Answer(QuestionsResponse{RequestID: "fake-id"})
+ })
+}
+
+func TestService_Concurrency(t *testing.T) {
+ srv := NewService()
+ var wg sync.WaitGroup
+
+ numRequests := 50
+ wg.Add(numRequests)
+
+ subCtx, subCancel := context.WithCancel(context.Background())
+ defer subCancel()
+ sub := srv.Subscribe(subCtx)
+
+ // Single goroutine to answer all questions as they come in via the pub/sub
+ go func() {
+ for event := range sub {
+ srv.Answer(QuestionsResponse{
+ RequestID: event.Payload.ID,
+ Answers: []Answer{{ID: "q1", Selected: []string{"opt1"}}},
+ })
+ }
+ }()
+
+ for i := 0; i < numRequests; i++ {
+ go func(idx int) {
+ defer wg.Done()
+
+ req := NewQuestionsRequest("sess1", "tool1", []Question{
+ {ID: "q1", Question: "Test?"},
+ })
+
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ resp, err := srv.Ask(ctx, req)
+ require.NoError(t, err)
+ require.Len(t, resp.Answers, 1)
+ require.Equal(t, "opt1", resp.Answers[0].Selected[0])
+ }(i)
+ }
+
+ wg.Wait()
+
+ // Verify cleanup
+ s := srv.(*questionsService)
+ require.Equal(t, 0, s.pendingRequests.Len(), "pendingRequests should be empty after all Asks return")
+}
diff --git a/internal/ui/dialog/actions.go b/internal/ui/dialog/actions.go
index a2de6513c1..79c1c57c0d 100644
--- a/internal/ui/dialog/actions.go
+++ b/internal/ui/dialog/actions.go
@@ -13,6 +13,7 @@ import (
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/questions"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/ui/common"
"github.com/charmbracelet/crush/internal/ui/util"
@@ -85,6 +86,10 @@ type (
ActionEnableDockerMCP struct{}
// ActionDisableDockerMCP is a message to disable Docker MCP.
ActionDisableDockerMCP struct{}
+ // ActionQuestionsResponse is sent when the user completes an ask_question dialog.
+ ActionQuestionsResponse struct {
+ Response questions.QuestionsResponse
+ }
)
// Messages for API key input dialog.
diff --git a/internal/ui/dialog/models_item.go b/internal/ui/dialog/models_item.go
index 645b26e987..6103be7ebf 100644
--- a/internal/ui/dialog/models_item.go
+++ b/internal/ui/dialog/models_item.go
@@ -104,15 +104,15 @@ func (m *ModelItem) ID() string {
func (m *ModelItem) Render(width int) string {
var providerInfo string
if m.showProvider {
- providerInfo = string(m.prov.Name)
+ providerInfo = m.prov.Name
}
- styles := ListItemStyles{
+ miStyles := ListItemStyles{
ItemBlurred: m.t.Dialog.NormalItem,
ItemFocused: m.t.Dialog.SelectedItem,
InfoTextBlurred: m.t.Base,
InfoTextFocused: m.t.Base,
}
- return renderItem(styles, m.model.Name, providerInfo, m.focused, width, m.cache, &m.m)
+ return renderItem(miStyles, m.model.Name, providerInfo, m.focused, width, m.cache, &m.m)
}
// SetFocused implements ListItem.
diff --git a/internal/ui/dialog/question_options.go b/internal/ui/dialog/question_options.go
new file mode 100644
index 0000000000..68a400d3f9
--- /dev/null
+++ b/internal/ui/dialog/question_options.go
@@ -0,0 +1,96 @@
+package dialog
+
+import (
+ "strings"
+
+ "charm.land/lipgloss/v2"
+ "github.com/charmbracelet/crush/internal/questions"
+ "github.com/charmbracelet/crush/internal/ui/list"
+ "github.com/charmbracelet/crush/internal/ui/styles"
+ "github.com/charmbracelet/x/ansi"
+)
+
+// questionOptionsList is a list of options for a question.
+type questionOptionsList struct {
+ *list.List
+
+ t *styles.Styles
+}
+
+// newQuestionOptionsList creates a new list of options for a question.
+func newQuestionOptionsList(sty *styles.Styles) *questionOptionsList {
+ l := &questionOptionsList{
+ List: list.NewList(),
+ t: sty,
+ }
+ l.RegisterRenderCallback(list.FocusedRenderCallback(l.List))
+ return l
+}
+
+// SetQuestion sets the question's options in the list,
+// and a map of which option is selected.
+func (l *questionOptionsList) SetQuestion(q questions.Question, selOpts map[int]bool) {
+ var items []list.Item
+ for i, opt := range q.Options {
+ items = append(items, &questionOptionsListItem{
+ parent: l,
+ opt: opt,
+ selected: selOpts[i],
+ index: i,
+ })
+ }
+ l.SetItems(items...)
+}
+
+// questionOptionsListItem is a list item for a question's option.
+type questionOptionsListItem struct {
+ parent *questionOptionsList
+ opt questions.Option
+ selected bool
+ focused bool
+ index int
+}
+
+func (i *questionOptionsListItem) Height() int {
+ return 1
+}
+
+func (i *questionOptionsListItem) String() string {
+ return i.opt.Label
+}
+
+// SetFocused implements ListItem.
+func (i *questionOptionsListItem) SetFocused(focused bool) {
+ i.focused = focused
+}
+
+func (i *questionOptionsListItem) Render(width int) string {
+ t := i.parent.t
+
+ // Setup styles
+ radioStyle := t.RadioOff.Bold(true)
+ if i.selected {
+ radioStyle = t.RadioOn.Foreground(t.Green).Bold(true)
+ }
+ labelStyle := t.Dialog.NormalItem
+ if i.focused {
+ labelStyle = t.Dialog.SelectedItem
+ }
+ descStyle := labelStyle.Italic(true).Foreground(t.FgHalfMuted)
+ gapStyle := labelStyle.Padding(0)
+
+ // Render each part
+ radioRender := radioStyle.Render()
+ labelRender := labelStyle.Render(i.opt.Label)
+ // NOTE: Only render the portion of description that fits on one line
+ descRender := ""
+ if len(i.opt.Description) > 0 {
+ // NOTE: `-2` is for the padding of the style
+ descAvailWidth := width - lipgloss.Width(radioRender) - lipgloss.Width(labelRender) - 2
+ optDesc := ansi.Truncate(i.opt.Description, max(0, descAvailWidth), "…")
+ descRender = descStyle.Render(optDesc)
+ }
+ gapRender := gapStyle.Render(strings.Repeat(" ", max(0, width-lipgloss.Width(radioRender)-lipgloss.Width(labelRender)-lipgloss.Width(descRender))))
+
+ return labelStyle.Render(radioRender + labelRender + gapRender + descRender)
+}
diff --git a/internal/ui/dialog/questions.go b/internal/ui/dialog/questions.go
new file mode 100644
index 0000000000..1e4a282e63
--- /dev/null
+++ b/internal/ui/dialog/questions.go
@@ -0,0 +1,234 @@
+package dialog
+
+import (
+ "log/slog"
+
+ "charm.land/bubbles/v2/help"
+ "charm.land/bubbles/v2/key"
+ tea "charm.land/bubbletea/v2"
+ "github.com/charmbracelet/crush/internal/questions"
+ "github.com/charmbracelet/crush/internal/ui/common"
+ uv "github.com/charmbracelet/ultraviolet"
+)
+
+const QuestionsID = "questions"
+
+type Questions struct {
+ com *common.Common
+
+ // Input
+ req questions.QuestionsRequest
+
+ // State
+ currQuestion int
+ selectedOptsByQuestion map[int]map[int]bool // map[questionIdx]map[optionIdx]bool
+
+ // Keyboard
+ keyMap struct {
+ UpDown key.Binding
+ Next key.Binding
+ Previous key.Binding
+ Select key.Binding
+ Submit key.Binding
+ Close key.Binding
+ }
+
+ // UI
+ list *questionOptionsList
+ help help.Model
+}
+
+func NewQuestionsDialog(com *common.Common, req questions.QuestionsRequest) *Questions {
+ d := &Questions{
+ com: com,
+ req: req,
+ currQuestion: 0,
+ selectedOptsByQuestion: make(map[int]map[int]bool),
+ list: newQuestionOptionsList(com.Styles),
+ help: help.New(),
+ }
+
+ d.keyMap.UpDown = key.NewBinding(
+ key.WithKeys("up", "down"),
+ key.WithHelp("↑/↓", "choose"),
+ )
+ d.keyMap.Next = key.NewBinding(
+ key.WithKeys("down", "ctrl+n"),
+ key.WithHelp("↓", "next option"),
+ )
+ d.keyMap.Previous = key.NewBinding(
+ key.WithKeys("up", "ctrl+p"),
+ key.WithHelp("↑", "previous option"),
+ )
+ d.keyMap.Select = key.NewBinding(
+ key.WithKeys("space"),
+ key.WithHelp("space", "select"),
+ )
+ d.keyMap.Submit = key.NewBinding(
+ key.WithKeys("enter"),
+ key.WithHelp("enter", "confirm"),
+ )
+ d.keyMap.Close = CloseKey
+
+ d.list.Focus()
+ d.initList()
+ d.list.SetSelected(0)
+
+ d.help.Styles = com.Styles.DialogHelpStyles()
+
+ return d
+}
+
+func (q *Questions) ID() string {
+ return QuestionsID
+}
+
+func (q *Questions) initList() {
+ // Return early if there are no questions (it should never happen)
+ if len(q.req.Questions) == 0 {
+ return
+ }
+
+ // Initialize map of selected options for current question
+ if q.selectedOptsByQuestion[q.currQuestion] == nil {
+ q.selectedOptsByQuestion[q.currQuestion] = make(map[int]bool)
+ }
+
+ q.refreshList()
+ q.list.SelectFirst()
+}
+
+func (q *Questions) refreshList() {
+ q.list.SetQuestion(
+ q.req.Questions[q.currQuestion],
+ q.selectedOptsByQuestion[q.currQuestion],
+ )
+}
+
+func (q *Questions) HandleMsg(msg tea.Msg) Action {
+ switch msg := msg.(type) {
+ case tea.KeyMsg:
+ switch {
+ case key.Matches(msg, q.keyMap.Previous):
+ q.list.Focus()
+ if q.list.IsSelectedFirst() {
+ q.list.SelectLast()
+ } else {
+ q.list.SelectPrev()
+ }
+ q.list.ScrollToSelected()
+ case key.Matches(msg, q.keyMap.Next):
+ q.list.Focus()
+ if q.list.IsSelectedLast() {
+ q.list.SelectFirst()
+ } else {
+ q.list.SelectNext()
+ }
+ q.list.ScrollToSelected()
+ case key.Matches(msg, q.keyMap.Select):
+ currQ := q.req.Questions[q.currQuestion]
+ idx := q.list.Selected()
+ if idx < 0 {
+ break
+ }
+ if !currQ.MultiSelect {
+ q.selectedOptsByQuestion[q.currQuestion] = make(map[int]bool)
+ q.selectedOptsByQuestion[q.currQuestion][idx] = true
+ } else {
+ q.selectedOptsByQuestion[q.currQuestion][idx] = !q.selectedOptsByQuestion[q.currQuestion][idx]
+ }
+ q.refreshList()
+ case key.Matches(msg, q.keyMap.Submit):
+ if q.currQuestion < len(q.req.Questions)-1 {
+ q.currQuestion++
+ q.initList()
+ } else {
+ slog.Info("Submitting QuestionsDialog with selected answers")
+
+ // Loop over all the Questions to assemble the Answers response
+ res := questions.NewQuestionsResponse(&q.req)
+ for questIdx, quest := range q.req.Questions {
+ // Create an Answer for each Question
+ ans := questions.NewAnswer(quest)
+ for optIdx, optSelected := range q.selectedOptsByQuestion[questIdx] {
+ // If the option is selected, select it on the Answer too
+ if optSelected {
+ ans.Select(quest.Options[optIdx].Label)
+ }
+ }
+ res.SetAnswerAt(questIdx, ans)
+ }
+
+ return ActionQuestionsResponse{Response: res}
+ }
+ case key.Matches(msg, q.keyMap.Close):
+ {
+ slog.Info("Closing QuestionsDialog: no answers provided, returning empty response")
+ return ActionQuestionsResponse{Response: questions.NewQuestionsResponse(&q.req)}
+ }
+ }
+ }
+ return nil
+}
+
+func (q *Questions) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
+ // Return early if there are no questions (it should never happen)
+ if len(q.req.Questions) == 0 {
+ return nil
+ }
+ // Determine current question
+ currQ := q.req.Questions[q.currQuestion]
+
+ // Styles shorthand
+ t := q.com.Styles
+
+ // Figure out dimensions
+ width := max(0, min(defaultDialogMaxWidth, area.Dx()-t.Dialog.View.GetHorizontalBorderSize()))
+ height := max(0, min(defaultDialogHeight, area.Dy()-t.Dialog.View.GetVerticalBorderSize()))
+ innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
+ heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
+ t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
+ t.Dialog.HelpView.GetVerticalFrameSize() +
+ t.Dialog.View.GetVerticalFrameSize()
+
+ // Set dimensions for List and Help bar
+ q.list.SetSize(innerWidth, height-heightOffset)
+ q.help.SetWidth(innerWidth)
+
+ rc := NewRenderContext(t, width)
+
+ // Render: Title
+ rc.Title = "Question"
+
+ // Render: Question
+ questionText := t.Dialog.TitleAccent.Italic(true).Padding(1, 2).Render(currQ.Question)
+ rc.AddPart(questionText)
+
+ // Render: Question's Options
+ listView := t.Dialog.List.Height(q.list.Height()).Render(q.list.Render())
+ rc.AddPart(listView)
+
+ // Render: Help
+ rc.Help = q.help.View(q)
+
+ view := rc.Render()
+ DrawCenterCursor(scr, area, view, nil)
+
+ return nil
+}
+
+// ShortHelp returns the short help view.
+func (q *Questions) ShortHelp() []key.Binding {
+ h := []key.Binding{
+ q.keyMap.UpDown,
+ q.keyMap.Select,
+ q.keyMap.Submit,
+ }
+ h = append(h, q.keyMap.Close)
+ return h
+}
+
+// FullHelp returns the full help view.
+func (q *Questions) FullHelp() [][]key.Binding {
+ return [][]key.Binding{q.ShortHelp()}
+}
diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go
index f631317523..ea6d7dbe14 100644
--- a/internal/ui/model/ui.go
+++ b/internal/ui/model/ui.go
@@ -37,6 +37,7 @@ import (
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/charmbracelet/crush/internal/questions"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/ui/anim"
"github.com/charmbracelet/crush/internal/ui/attachments"
@@ -644,6 +645,13 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
case pubsub.Event[permission.PermissionNotification]:
m.handlePermissionNotification(msg.Payload)
+ case pubsub.Event[questions.QuestionsRequest]:
+ d := dialog.NewQuestionsDialog(m.com, msg.Payload)
+ m.dialog.OpenDialog(d)
+ cmds = append(cmds, m.sendNotification(notification.Notification{
+ Title: "Crush is waiting...",
+ Message: "There is a question for you.",
+ }))
case cancelTimerExpiredMsg:
m.isCanceling = false
case tea.TerminalVersionMsg:
@@ -1506,6 +1514,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
return util.NewInfoMsg("Reasoning effort set to " + msg.Effort)
})
m.dialog.CloseDialog(dialog.ReasoningID)
+ case dialog.ActionQuestionsResponse:
+ m.com.Workspace.QuestionsAnswer(msg.Response)
+ m.dialog.CloseDialog(dialog.QuestionsID)
case dialog.ActionPermissionResponse:
m.dialog.CloseDialog(dialog.PermissionsID)
switch msg.Action {
diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go
index 57b1228e7e..e0ac2120c2 100644
--- a/internal/workspace/app_workspace.go
+++ b/internal/workspace/app_workspace.go
@@ -17,6 +17,7 @@ import (
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/questions"
"github.com/charmbracelet/crush/internal/session"
)
@@ -187,6 +188,12 @@ func (w *AppWorkspace) PermissionSetSkipRequests(skip bool) {
w.app.Permissions.SetSkipRequests(skip)
}
+// -- Questions --
+
+func (w *AppWorkspace) QuestionsAnswer(response questions.QuestionsResponse) {
+ w.app.Questions.Answer(response)
+}
+
// -- FileTracker --
func (w *AppWorkspace) FileTrackerRecordRead(ctx context.Context, sessionID, path string) {
diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go
index 7c4e140888..eb82fd21de 100644
--- a/internal/workspace/client_workspace.go
+++ b/internal/workspace/client_workspace.go
@@ -21,6 +21,7 @@ import (
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/charmbracelet/crush/internal/questions"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
)
@@ -304,6 +305,10 @@ func (w *ClientWorkspace) PermissionSetSkipRequests(skip bool) {
_ = w.client.SetPermissionsSkipRequests(context.Background(), w.workspaceID(), skip)
}
+func (w *ClientWorkspace) QuestionsAnswer(res questions.QuestionsResponse) {
+ // TODO Implement for ClientWorkspace.
+}
+
// -- FileTracker --
func (w *ClientWorkspace) FileTrackerRecordRead(ctx context.Context, sessionID, path string) {
diff --git a/internal/workspace/workspace.go b/internal/workspace/workspace.go
index 02c54c616f..01f2847757 100644
--- a/internal/workspace/workspace.go
+++ b/internal/workspace/workspace.go
@@ -17,6 +17,7 @@ import (
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/questions"
"github.com/charmbracelet/crush/internal/session"
)
@@ -95,6 +96,9 @@ type Workspace interface {
PermissionSkipRequests() bool
PermissionSetSkipRequests(skip bool)
+ // Questions
+ QuestionsAnswer(res questions.QuestionsResponse)
+
// FileTracker
FileTrackerRecordRead(ctx context.Context, sessionID, path string)
FileTrackerLastReadTime(ctx context.Context, sessionID, path string) time.Time