diff --git a/AGENTS.md b/AGENTS.md index e668316..acac344 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -29,7 +29,9 @@ CRITICAL: every Playwright artifact save (`playwright_browser_take_screenshot`, CRITICAL: bare Playwright filenames like `foo.png`, `foo.md`, or `foo.txt` are forbidden because they save into the repository root and pollute the workspace. -For React/TypeScript code in cmd/sgai/webapp/, use bun for building, testing, and running scripts. Build command: `bun run build`. Dev server: `bun run dev.ts`. Tests: `bun test`. +For React/TypeScript code in cmd/sgai/webapp/, use bun for building, testing, and running scripts. Run these commands from `cmd/sgai/webapp/`. Build command: `bun run build`. Dev server: `bun run dev.ts`. Tests: `bun test src/`. + +Bare repo-root `bun test` is misleading in this repository because it can sweep unrelated root-level TypeScript/Playwright files instead of just the webapp package. React components must use shadcn/ui components where possible. Do not create custom implementations when a shadcn component exists. Reference: https://ui.shadcn.com/docs diff --git a/cmd/sgai/exhaustruct_test.go b/cmd/sgai/exhaustruct_test.go index 0f0e24f..c9fa1eb 100644 --- a/cmd/sgai/exhaustruct_test.go +++ b/cmd/sgai/exhaustruct_test.go @@ -38,11 +38,14 @@ func testWorkflow() state.Workflow { Task: "", Progress: nil, HumanMessage: "", + HumanInputAgent: "", MultiChoiceQuestion: nil, Messages: nil, VisitCounts: nil, CurrentAgent: "", + AgentStates: nil, Todos: nil, + TodosByAgent: nil, ProjectTodos: nil, AgentSequence: nil, SessionID: "", diff --git a/cmd/sgai/main.go b/cmd/sgai/main.go index 71e1428..004feea 100644 --- a/cmd/sgai/main.go +++ b/cmd/sgai/main.go @@ -138,7 +138,9 @@ func unlockInteractiveForRetrospective(wfState *state.Workflow, currentAgent str return nil } wfState.InteractionMode = state.ModeRetrospective - if errSave := saveState(coord, wfState); errSave != nil { + if errSave := coord.UpdateState(func(currentState *state.Workflow) { + currentState.InteractionMode = state.ModeRetrospective + }); errSave != nil { return fmt.Errorf("save state for retrospective unlock: %w", errSave) } log.Println("["+paddedsgai+"]", "transitioning to retrospective mode") @@ -305,17 +307,18 @@ type multiModelConfig struct { logWriter io.Writer stdoutLog io.Writer stderrLog io.Writer + nextIteration func() int } -func runMultiModelAgent(ctx context.Context, cfg *multiModelConfig, wfState *state.Workflow, metadata *GoalMetadata, iterationCounter *int) state.Workflow { +func runMultiModelAgent(ctx context.Context, cfg *multiModelConfig, wfState *state.Workflow, metadata *GoalMetadata) state.Workflow { currentState := *wfState models := getModelsForAgent(metadata.Models, cfg.agent) if len(models) <= 1 { - return runSingleModelIteration(ctx, cfg, ¤tState, metadata, iterationCounter, models) + return runSingleModelIteration(ctx, cfg, ¤tState, metadata, models) } currentState.ModelStatuses = syncModelStatuses(currentState.ModelStatuses, models, cfg.agent) - if errSave := saveState(cfg.coord, ¤tState); errSave != nil { + if errSave := persistMultiModelState(cfg.coord, cfg.agent, currentState.ModelStatuses, currentState.CurrentModel, ""); errSave != nil { return failWorkflowState(cfg, ¤tState, "failed to save state before multi-model loop: %v", errSave) } @@ -335,7 +338,7 @@ func runMultiModelAgent(ctx context.Context, cfg *multiModelConfig, wfState *sta if len(newModels) <= 1 { log.Println("["+cfg.paddedsgai+"]", "switching to single-model mode for", cfg.agent) cleanupModelStatuses(¤tState) - return runSingleModelIteration(ctx, cfg, ¤tState, metadata, iterationCounter, newModels) + return runSingleModelIteration(ctx, cfg, ¤tState, metadata, newModels) } currentState.ModelStatuses = syncModelStatuses(currentState.ModelStatuses, newModels, cfg.agent) @@ -362,12 +365,12 @@ func runMultiModelAgent(ctx context.Context, cfg *multiModelConfig, wfState *sta } currentState.CurrentModel = modelID - if errSave := saveState(cfg.coord, ¤tState); errSave != nil { + if errSave := persistMultiModelState(cfg.coord, cfg.agent, currentState.ModelStatuses, currentState.CurrentModel, ""); errSave != nil { return failWorkflowState(cfg, ¤tState, "failed to save state before model iteration: %v", errSave) } log.Println("["+cfg.paddedsgai+"]", "running model:", modelID) - currentState = runSingleModelIteration(ctx, cfg, ¤tState, metadata, iterationCounter, []string{modelSpec}) + currentState = runSingleModelIteration(ctx, cfg, ¤tState, metadata, []string{modelSpec}) newState := cfg.coord.State() @@ -375,7 +378,7 @@ func runMultiModelAgent(ctx context.Context, cfg *multiModelConfig, wfState *sta case state.StatusAgentDone: currentState.ModelStatuses[modelID] = "model-done" currentState.Status = state.StatusWorking - if errSave := saveState(cfg.coord, ¤tState); errSave != nil { + if errSave := persistMultiModelState(cfg.coord, cfg.agent, currentState.ModelStatuses, currentState.CurrentModel, currentState.Status); errSave != nil { return failWorkflowState(cfg, ¤tState, "failed to save state after model done: %v", errSave) } case state.StatusComplete: @@ -387,7 +390,7 @@ func runMultiModelAgent(ctx context.Context, cfg *multiModelConfig, wfState *sta log.Println("["+cfg.paddedsgai+"]", "multi-model consensus reached for", cfg.agent) cleanupModelStatuses(¤tState) currentState.Status = state.StatusAgentDone - if errSave := saveState(cfg.coord, ¤tState); errSave != nil { + if errSave := persistMultiModelState(cfg.coord, cfg.agent, currentState.ModelStatuses, currentState.CurrentModel, currentState.Status); errSave != nil { return failWorkflowState(cfg, ¤tState, "failed to save state after consensus: %v", errSave) } return currentState @@ -395,15 +398,15 @@ func runMultiModelAgent(ctx context.Context, cfg *multiModelConfig, wfState *sta } } -func runSingleModelIteration(ctx context.Context, cfg *multiModelConfig, wfState *state.Workflow, metadata *GoalMetadata, iterationCounter *int, models []string) state.Workflow { +func runSingleModelIteration(ctx context.Context, cfg *multiModelConfig, wfState *state.Workflow, metadata *GoalMetadata, models []string) state.Workflow { modelSpec := "" if len(models) > 0 { modelSpec = models[0] } - return runFlowAgentWithModel(ctx, cfg, wfState, metadata, iterationCounter, modelSpec) + return runFlowAgentWithModel(ctx, cfg, wfState, metadata, modelSpec) } -func runFlowAgentWithModel(ctx context.Context, cfg *multiModelConfig, wfState *state.Workflow, metadata *GoalMetadata, iterationCounter *int, modelSpec string) state.Workflow { +func runFlowAgentWithModel(ctx context.Context, cfg *multiModelConfig, wfState *state.Workflow, metadata *GoalMetadata, modelSpec string) state.Workflow { currentState := *wfState paddedAgentName := cfg.agent + strings.Repeat(" ", max(0, cfg.longestNameLen-len(cfg.agent))) var capturedSessionID string @@ -416,12 +419,12 @@ func runFlowAgentWithModel(ctx context.Context, cfg *multiModelConfig, wfState * return currentState } - *iterationCounter++ - prefix := buildAgentPrefix(cfg.dir, paddedAgentName, *iterationCounter) - - if errSave := saveState(cfg.coord, ¤tState); errSave != nil { - return failWorkflowState(cfg, ¤tState, "failed to save state: %v", errSave) + iteration := 1 + if cfg.nextIteration != nil { + iteration = cfg.nextIteration() } + prefix := buildAgentPrefix(cfg.dir, paddedAgentName, iteration) + if errCopy := copyProjectManagementToRetrospective(cfg.dir, cfg.retrospectiveDir); errCopy != nil { log.Println("failed to copy PROJECT_MANAGEMENT.md to retrospective:", errCopy) } @@ -436,7 +439,7 @@ func runFlowAgentWithModel(ctx context.Context, cfg *multiModelConfig, wfState * } if cfg.retrospectiveDir != "" && capturedSessionID != "" && shouldLogAgent(cfg.dir, baseAgent) { - if errExport := exportAgentSession(cfg, capturedSessionID, *iterationCounter); errExport != nil { + if errExport := exportAgentSession(cfg, capturedSessionID, iteration); errExport != nil { log.Println("failed to export session:", errExport) } } @@ -450,16 +453,10 @@ func runFlowAgentWithModel(ctx context.Context, cfg *multiModelConfig, wfState * return handleCompleteStatus(ctx, cfg, &newState, metadata) case state.StatusAgentDone: - if errSave := saveState(cfg.coord, &newState); errSave != nil { - return failWorkflowState(cfg, &newState, "failed to save state: %v", errSave) - } log.Println("["+cfg.paddedsgai+"]", "agent", cfg.agent, "done:", newState.Task) return newState case state.StatusWorking: - if errSave := saveState(cfg.coord, &newState); errSave != nil { - return failWorkflowState(cfg, &newState, "failed to save state: %v", errSave) - } if agentHasUnreadOutgoingMessages(&newState, cfg.agent) { log.Println("["+cfg.paddedsgai+"]", "agent", cfg.agent, "sent message(s), yielding control...") return newState @@ -480,25 +477,39 @@ func buildAgentPrefix(dir, paddedAgentName string, iteration int) string { } func saveState(coord *state.Coordinator, wfState *state.Workflow) error { + if errReplace := coord.ReplaceState(wfState); errReplace != nil { + return fmt.Errorf("save state: %w", errReplace) + } + return nil +} + +func persistMultiModelState(coord *state.Coordinator, agent string, modelStatuses map[string]string, currentModel, status string) error { if errUpdate := coord.UpdateState(func(wf *state.Workflow) { - *wf = *wfState + wf.ModelStatuses = maps.Clone(modelStatuses) + wf.CurrentModel = currentModel + if status != "" { + updateAgentWorkflowState(wf, agent, status, true, "", false) + } }); errUpdate != nil { - return fmt.Errorf("save state: %w", errUpdate) + return fmt.Errorf("save multi-model state: %w", errUpdate) } return nil } func failWorkflowState(cfg *multiModelConfig, wfState *state.Workflow, format string, args ...any) state.Workflow { - currentState := *wfState message := fmt.Sprintf(format, args...) log.Println(message) - currentState.Status = state.StatusAgentDone - currentState.Task = message - addEnvironmentMessage(¤tState, cfg.agent, message) - if errSave := saveState(cfg.coord, ¤tState); errSave != nil { + if errSave := cfg.coord.UpdateState(func(currentState *state.Workflow) { + updateAgentWorkflowState(currentState, cfg.agent, state.StatusAgentDone, true, message, true) + addEnvironmentMessage(currentState, cfg.agent, message) + }); errSave != nil { log.Println("failed to save workflow failure state:", errSave) + fallbackState := *wfState + updateAgentWorkflowState(&fallbackState, cfg.agent, state.StatusAgentDone, true, message, true) + addEnvironmentMessage(&fallbackState, cfg.agent, message) + return fallbackState } - return currentState + return cfg.coord.State() } func copyProjectManagementToRetrospective(dir, retrospectiveDir string) error { @@ -614,13 +625,12 @@ func executeAgentProcess(ctx context.Context, cfg *multiModelConfig, agentArgs [ stdoutOut := buildAgentOutputWriter(os.Stdout, cfg.logWriter, cfg.stdoutLog) stderrWriter := newPrefixWriter(prefix+" ", stderrOut, time.Now) jsonWriter := newJSONPrettyWriter(prefix+" ", stdoutOut, cfg.coord, cfg.agent, time.Now) - - cfg.coord.ResetAgentDoneWatchdog() + watchdogStarted := false agentEnv, errBuildAgentEnv := buildAgentEnv(cfg, extractModelFromArgs(agentArgs)) if errBuildAgentEnv != nil { log.Println("failed to prepare agent environment:", errBuildAgentEnv) if errUpdate := cfg.coord.UpdateState(func(wf *state.Workflow) { - wf.Status = state.StatusAgentDone + updateAgentWorkflowState(wf, cfg.agent, state.StatusAgentDone, true, "", false) }); errUpdate != nil { log.Println("failed to save state:", errUpdate) } @@ -630,7 +640,21 @@ func executeAgentProcess(ctx context.Context, cfg *multiModelConfig, agentArgs [ } agentCtx, agentCancel := context.WithCancel(ctx) - cfg.coord.SetAgentCancel(agentCancel) + stopAgentDoneWatchdog := func() { + if jsonWriter.stopAgentDoneWatchdog != nil { + jsonWriter.stopAgentDoneWatchdog() + } + } + jsonWriter.startAgentDoneWatchdog = func() { + if watchdogStarted { + return + } + watchdogStarted = true + watchdog := time.AfterFunc(state.AgentDoneWatchdogTimeout, agentCancel) + jsonWriter.stopAgentDoneWatchdog = func() { + watchdog.Stop() + } + } cmd := exec.CommandContext(agentCtx, "opencode", agentArgs...) cmd.Dir = cfg.dir @@ -645,12 +669,13 @@ func executeAgentProcess(ctx context.Context, cfg *multiModelConfig, agentArgs [ if errStart := cmd.Start(); errStart != nil { agentCancel() log.Println("failed to start opencode:", errStart) - result := cfg.coord.State() - result.Status = state.StatusAgentDone - if errSave := saveState(cfg.coord, &result); errSave != nil { + if errSave := cfg.coord.UpdateState(func(wf *state.Workflow) { + updateAgentWorkflowState(wf, cfg.agent, state.StatusAgentDone, true, "", false) + }); errSave != nil { log.Println("failed to save state:", errSave) } log.Println("agent", cfg.agent, "marked as agent-done due to start failure") + result := cfg.coord.State() return zeroState, "", &result } @@ -659,10 +684,10 @@ func executeAgentProcess(ctx context.Context, cfg *multiModelConfig, agentArgs [ errWait := cmd.Wait() close(processExited) - cfg.coord.Stop() agentCancel() flushPrefixWriterWithLog("agent stderr", stderrWriter) jsonWriter.Flush() + stopAgentDoneWatchdog() if errWait != nil { if ctx.Err() != nil { @@ -672,12 +697,13 @@ func executeAgentProcess(ctx context.Context, cfg *multiModelConfig, agentArgs [ log.Println("=== RAW AGENT OUTPUT (last 1000 lines) ===") outputCapture.dump(os.Stderr) log.Println("=== END RAW AGENT OUTPUT ===") - result := cfg.coord.State() - result.Status = state.StatusAgentDone - if errSave := saveState(cfg.coord, &result); errSave != nil { + if errSave := cfg.coord.UpdateState(func(wf *state.Workflow) { + updateAgentWorkflowState(wf, cfg.agent, state.StatusAgentDone, true, "", false) + }); errSave != nil { log.Println("failed to save state:", errSave) } log.Println("agent", cfg.agent, "marked as agent-done due to error", errWait) + result := cfg.coord.State() return zeroState, "", &result } @@ -724,11 +750,12 @@ func exportAgentSession(cfg *multiModelConfig, sessionID string, iteration int) func handleCompleteStatus(ctx context.Context, cfg *multiModelConfig, newState *state.Workflow, metadata *GoalMetadata) state.Workflow { if cfg.agent != "coordinator" { log.Println("["+cfg.paddedsgai+"]", "agent", cfg.agent, "set status=complete, only coordinator can complete workflow; treating as agent-done") - newState.Status = state.StatusAgentDone - if errSave := saveState(cfg.coord, newState); errSave != nil { + if errSave := cfg.coord.UpdateState(func(wf *state.Workflow) { + updateAgentWorkflowState(wf, cfg.agent, state.StatusAgentDone, true, "", false) + }); errSave != nil { return failWorkflowState(cfg, newState, "failed to save state: %v", errSave) } - return *newState + return cfg.coord.State() } if blocked := blockCompletionOnPendingTodos(cfg, newState); blocked != nil { @@ -947,14 +974,14 @@ func addAgentHandoffProgress(wfState *state.Workflow, targetAgent string) { // it as current; otherwise, it appends a new entry. func markCurrentAgentInSequence(s *state.Workflow, currentAgent string) { now := time.Now().UTC().Format(time.RFC3339) + for i := range s.AgentSequence { + s.AgentSequence[i].IsCurrent = false + } lastIdx := len(s.AgentSequence) - 1 if lastIdx >= 0 && s.AgentSequence[lastIdx].Agent == currentAgent { s.AgentSequence[lastIdx].IsCurrent = true return } - for i := range s.AgentSequence { - s.AgentSequence[i].IsCurrent = false - } s.AgentSequence = append(s.AgentSequence, state.AgentSequenceEntry{ Agent: currentAgent, StartTime: now, @@ -962,11 +989,88 @@ func markCurrentAgentInSequence(s *state.Workflow, currentAgent string) { }) } +func markCurrentAgentsInSequence(s *state.Workflow, currentAgents []string) { + if len(currentAgents) == 0 { + return + } + if len(currentAgents) == 1 { + markCurrentAgentInSequence(s, currentAgents[0]) + return + } + + now := time.Now().UTC().Format(time.RFC3339) + for i := range s.AgentSequence { + s.AgentSequence[i].IsCurrent = false + } + for _, currentAgent := range currentAgents { + s.AgentSequence = append(s.AgentSequence, state.AgentSequenceEntry{ + Agent: currentAgent, + StartTime: now, + IsCurrent: true, + }) + } +} + +func splitCurrentAgents(currentAgent string) []string { + if currentAgent == "" { + return nil + } + parts := strings.Split(currentAgent, ", ") + result := make([]string, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} + +func formatCurrentAgents(currentAgents []string) string { + if len(currentAgents) == 0 { + return "" + } + if len(currentAgents) == 1 { + return currentAgents[0] + } + return strings.Join(currentAgents, ", ") +} + +func hasParallelCurrentAgents(currentAgent string) bool { + return len(splitCurrentAgents(currentAgent)) > 1 +} + +func setVisibleAgentTodos(wf *state.Workflow, agent string) { + currentAgents := splitCurrentAgents(wf.CurrentAgent) + if len(currentAgents) == 0 { + if wf.TodosByAgent == nil { + wf.Todos = nil + return + } + wf.Todos = slices.Clone(wf.TodosByAgent[agent]) + return + } + if len(currentAgents) != 1 || currentAgents[0] != agent { + wf.Todos = nil + return + } + if wf.TodosByAgent == nil { + wf.Todos = nil + return + } + wf.Todos = slices.Clone(wf.TodosByAgent[agent]) +} + // todosForAgent returns the TODO list enforced for the given agent. func todosForAgent(wf *state.Workflow, agent string) []state.TodoItem { if agent == "coordinator" { return wf.ProjectTodos } + if wf.TodosByAgent != nil { + if todos, found := wf.TodosByAgent[agent]; found { + return todos + } + } return wf.Todos } @@ -1310,33 +1414,6 @@ func parseYAMLFrontmatter(content []byte) (GoalMetadata, error) { //go:embed skel/** var skelFS embed.FS -func findFirstPendingMessageAgent(messages []state.Message) string { - if len(messages) == 0 { - return "" - } - for _, msg := range messages { - if !msg.Read { - return extractAgentFromModelID(msg.ToAgent) - } - } - return "" -} - -func redirectToPendingMessageAgent(s *state.Workflow, coord *state.Coordinator, paddedsgai string) (bool, error) { - pendingAgent := findFirstPendingMessageAgent(s.Messages) - if pendingAgent == "" { - return false, nil - } - log.Println("["+paddedsgai+"]", "pending messages for", pendingAgent, "- redirecting before completion") - s.Status = state.StatusWorking - s.CurrentAgent = pendingAgent - s.VisitCounts[pendingAgent]++ - if errSave := saveState(coord, s); errSave != nil { - return false, fmt.Errorf("save state while redirecting to pending message agent: %w", errSave) - } - return true, nil -} - func runCompletionGateScript(ctx context.Context, dir, script string) (string, error) { cmd := exec.CommandContext(ctx, "sh", "-c", script) cmd.Dir = dir @@ -1630,28 +1707,32 @@ type toolState struct { } type jsonPrettyWriter struct { - prefix string - w io.Writer - buf []byte - currentText strings.Builder - sessionID string - coord *state.Coordinator - currentAgent string - stepCounter int - now func() time.Time + prefix string + w io.Writer + buf []byte + currentText strings.Builder + sessionID string + coord *state.Coordinator + currentAgent string + stepCounter int + now func() time.Time + startAgentDoneWatchdog func() + stopAgentDoneWatchdog func() } func newJSONPrettyWriter(prefix string, w io.Writer, coord *state.Coordinator, currentAgent string, now func() time.Time) *jsonPrettyWriter { return &jsonPrettyWriter{ - prefix: prefix, - w: w, - buf: nil, - currentText: strings.Builder{}, - sessionID: "", - coord: coord, - currentAgent: currentAgent, - stepCounter: 0, - now: now, + prefix: prefix, + w: w, + buf: nil, + currentText: strings.Builder{}, + sessionID: "", + coord: coord, + currentAgent: currentAgent, + stepCounter: 0, + now: now, + startAgentDoneWatchdog: nil, + stopAgentDoneWatchdog: nil, } } @@ -1730,6 +1811,9 @@ func (j *jsonPrettyWriter) processEvent(event *streamEvent) { if _, err := fmt.Fprintln(j.w, j.tsPrefix()+toolCall); err != nil { log.Println("write failed:", err) } + if part.Tool == "update_workflow_state" && toolRequestsAgentDone(part.State.Input) && j.startAgentDoneWatchdog != nil { + j.startAgentDoneWatchdog() + } if part.State.Output != "" { if isTodoTool(part.Tool) { j.formatTodoOutput(part.State.Output) @@ -1871,6 +1955,21 @@ func isTodoTool(tool string) bool { } } +func toolRequestsAgentDone(input map[string]any) bool { + if input == nil { + return false + } + statusValue, found := input["status"] + if !found { + return false + } + status, ok := statusValue.(string) + if !ok { + return false + } + return status == state.StatusAgentDone +} + func (j *jsonPrettyWriter) formatTodoOutput(output string) { todos, ok := parseTodoOutput(output) if !ok { @@ -1901,7 +2000,11 @@ func (j *jsonPrettyWriter) updateAgentTodos(tool, output string) { } if errUpdate := j.coord.UpdateState(func(wf *state.Workflow) { - wf.Todos = slices.Clone(todos) + if wf.TodosByAgent == nil { + wf.TodosByAgent = make(map[string][]state.TodoItem) + } + wf.TodosByAgent[j.currentAgent] = slices.Clone(todos) + setVisibleAgentTodos(wf, j.currentAgent) }); errUpdate != nil { log.Println("failed to save todos:", errUpdate) } diff --git a/cmd/sgai/main_test.go b/cmd/sgai/main_test.go index 561c1f9..8cbf072 100644 --- a/cmd/sgai/main_test.go +++ b/cmd/sgai/main_test.go @@ -42,6 +42,22 @@ func newTestDag(nodeNames ...string) *dag { } } +func buildTestDag(edges map[string][]string, entryNodes []string) *dag { + d := &dag{ + Nodes: make(map[string]*dagNode), + EntryNodes: entryNodes, + } + for from, toList := range edges { + node := d.ensureNode(from) + for _, to := range toList { + toNode := d.ensureNode(to) + node.Successors = append(node.Successors, to) + toNode.Predecessors = append(toNode.Predecessors, from) + } + } + return d +} + func newTestGoalMetadata() GoalMetadata { return GoalMetadata{ Title: "", @@ -103,11 +119,14 @@ func newTestWorkflow() state.Workflow { Task: "", Progress: nil, HumanMessage: "", + HumanInputAgent: "", MultiChoiceQuestion: nil, Messages: nil, VisitCounts: nil, CurrentAgent: "", + AgentStates: nil, Todos: nil, + TodosByAgent: nil, ProjectTodos: nil, AgentSequence: nil, SessionID: "", @@ -258,6 +277,7 @@ func newTestMultiModelConfig() multiModelConfig { logWriter: nil, stdoutLog: nil, stderrLog: nil, + nextIteration: nil, } } @@ -320,15 +340,17 @@ func newTestStreamEvent() streamEvent { func newTestJSONPrettyWriter() *jsonPrettyWriter { return &jsonPrettyWriter{ - prefix: "", - w: nil, - buf: nil, - currentText: strings.Builder{}, - sessionID: "", - coord: nil, - currentAgent: "", - stepCounter: 0, - now: testLogNow, + prefix: "", + w: nil, + buf: nil, + currentText: strings.Builder{}, + sessionID: "", + coord: nil, + currentAgent: "", + stepCounter: 0, + now: testLogNow, + startAgentDoneWatchdog: nil, + stopAgentDoneWatchdog: nil, } } @@ -976,55 +998,6 @@ func TestHandleCompleteStatus(t *testing.T) { }) } -func TestRedirectToPendingMessageAgent(t *testing.T) { - t.Run("noMessages", func(t *testing.T) { - dir := t.TempDir() - coord, err := state.NewCoordinatorWith(filepath.Join(dir, "state.json"), newTestWorkflow()) - require.NoError(t, err) - - wfState := newTestWorkflow() - result, errRedirect := redirectToPendingMessageAgent(&wfState, coord, "sgai") - require.NoError(t, errRedirect) - assert.False(t, result) - }) - - t.Run("allMessagesRead", func(t *testing.T) { - dir := t.TempDir() - coord, err := state.NewCoordinatorWith(filepath.Join(dir, "state.json"), newTestWorkflow()) - require.NoError(t, err) - - wfState := newTestWorkflow() - message := newTestMessage() - message.ID = 1 - message.ToAgent = "dev" - message.Read = true - wfState.Messages = []state.Message{message} - result, errRedirect := redirectToPendingMessageAgent(&wfState, coord, "sgai") - require.NoError(t, errRedirect) - assert.False(t, result) - }) - - t.Run("unreadMessageRedirects", func(t *testing.T) { - dir := t.TempDir() - statePath := filepath.Join(dir, "state.json") - coord, err := state.NewCoordinatorWith(statePath, newTestWorkflow()) - require.NoError(t, err) - - wfState := newTestWorkflow() - wfState.VisitCounts = map[string]int{} - message := newTestMessage() - message.ID = 1 - message.ToAgent = "developer" - message.Read = false - wfState.Messages = []state.Message{message} - result, errRedirect := redirectToPendingMessageAgent(&wfState, coord, "sgai") - require.NoError(t, errRedirect) - assert.True(t, result) - assert.Equal(t, "developer", wfState.CurrentAgent) - assert.Equal(t, state.StatusWorking, wfState.Status) - }) -} - func TestBuildAgentArgsVariants(t *testing.T) { cases := []struct { name string @@ -1445,31 +1418,6 @@ func TestResolveBaseAgentWithAlias(t *testing.T) { assert.Equal(t, "builder", resolveBaseAgent(alias, "builder")) } -func TestFindFirstPendingMessageAgentVariants(t *testing.T) { - t.Run("noMessages", func(t *testing.T) { - assert.Empty(t, findFirstPendingMessageAgent(newTestWorkflow().Messages)) - }) - - t.Run("allRead", func(t *testing.T) { - wf := newTestWorkflow() - message := newTestMessage() - message.ToAgent = "builder" - message.Read = true - wf.Messages = []state.Message{message} - assert.Empty(t, findFirstPendingMessageAgent(wf.Messages)) - }) - - t.Run("unreadForAgent", func(t *testing.T) { - wf := newTestWorkflow() - message := newTestMessage() - message.ToAgent = "builder" - message.Read = false - wf.Messages = []state.Message{message} - wf.CurrentAgent = "coordinator" - assert.Equal(t, "builder", findFirstPendingMessageAgent(wf.Messages)) - }) -} - func TestValidateModelsPartial(t *testing.T) { t.Run("emptyModels", func(t *testing.T) { err := validateModels(nil) @@ -1534,6 +1482,36 @@ func TestSaveState(t *testing.T) { assert.Equal(t, state.StatusComplete, updated.Status) } +func TestSaveStateDetachesCallerOwnedReferenceFields(t *testing.T) { + dir := t.TempDir() + statePath := filepath.Join(dir, ".sgai", "state.json") + require.NoError(t, os.MkdirAll(filepath.Dir(statePath), 0o755)) + + coord, errCoord := state.NewCoordinatorWith(statePath, state.NewWorkflow()) + require.NoError(t, errCoord) + + wf := state.NewWorkflow() + wf.Status = state.StatusWorking + wf.Task = "stable task" + wf.Progress = []state.ProgressEntry{{Timestamp: "", Agent: "", Description: "stable progress"}} + wf.Messages = []state.Message{updated(newTestMessage(), func(message *state.Message) { + message.ID = 1 + message.Body = "stable message" + })} + wf.VisitCounts["coordinator"] = 1 + wf.TodosByAgent["go-developer"] = []state.TodoItem{{ID: "todo-1", Content: "stable todo", Status: "pending", Priority: "high"}} + + require.NoError(t, saveState(coord, &wf)) + saved := coord.State() + + wf.Progress[0].Description = "mutated progress" + wf.Messages[0].Body = "mutated message" + wf.VisitCounts["coordinator"] = 99 + wf.TodosByAgent["go-developer"][0].Content = "mutated todo" + + assert.Equal(t, saved, coord.State()) +} + func TestSaveStateReturnsErrorOnPersistFailure(t *testing.T) { dir := t.TempDir() blockingPath := filepath.Join(dir, "blocking-file") @@ -1961,64 +1939,6 @@ func TestRetrospectiveEnabled(t *testing.T) { } } -func TestFindFirstPendingMessageAgent(t *testing.T) { - tests := []struct { - name string - workflow state.Workflow - expected string - }{ - { - name: "noMessages", - workflow: func() state.Workflow { - workflow := newTestWorkflow() - workflow.Messages = []state.Message{} - return workflow - }(), - expected: "", - }, - { - name: "allRead", - workflow: func() state.Workflow { - workflow := newTestWorkflow() - messageOne := newTestMessage() - messageOne.ToAgent = "agent1" - messageOne.Read = true - messageTwo := newTestMessage() - messageTwo.ToAgent = "agent2" - messageTwo.Read = true - workflow.Messages = []state.Message{messageOne, messageTwo} - return workflow - }(), - expected: "", - }, - { - name: "firstUnread", - workflow: func() state.Workflow { - workflow := newTestWorkflow() - messageOne := newTestMessage() - messageOne.ToAgent = "agent1" - messageOne.Read = true - messageTwo := newTestMessage() - messageTwo.ToAgent = "agent2" - messageTwo.Read = false - messageThree := newTestMessage() - messageThree.ToAgent = "agent3" - messageThree.Read = false - workflow.Messages = []state.Message{messageOne, messageTwo, messageThree} - return workflow - }(), - expected: "agent2", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := findFirstPendingMessageAgent(tt.workflow.Messages) - assert.Equal(t, tt.expected, result) - }) - } -} - func TestExtractFrontmatterDescription(t *testing.T) { tests := []struct { name string @@ -3605,6 +3525,36 @@ func TestMarkCurrentAgentInSequence(t *testing.T) { } } +func TestMarkCurrentAgentsInSequence(t *testing.T) { + wf := newTestWorkflow() + markCurrentAgentsInSequence(&wf, []string{"go-developer", "react-developer"}) + require.Len(t, wf.AgentSequence, 2) + assert.Equal(t, "go-developer", wf.AgentSequence[0].Agent) + assert.True(t, wf.AgentSequence[0].IsCurrent) + assert.Equal(t, "react-developer", wf.AgentSequence[1].Agent) + assert.True(t, wf.AgentSequence[1].IsCurrent) +} + +func TestMarkCurrentAgentsInSequenceClearsStaleParallelCurrentFlags(t *testing.T) { + wf := newTestWorkflow() + wf.AgentSequence = []state.AgentSequenceEntry{ + updated(newTestAgentSequenceEntry(), func(entry *state.AgentSequenceEntry) { + entry.Agent = "go-developer" + entry.IsCurrent = true + }), + updated(newTestAgentSequenceEntry(), func(entry *state.AgentSequenceEntry) { + entry.Agent = "react-developer" + entry.IsCurrent = true + }), + } + + markCurrentAgentsInSequence(&wf, []string{"react-developer"}) + + require.Len(t, wf.AgentSequence, 2) + assert.False(t, wf.AgentSequence[0].IsCurrent) + assert.True(t, wf.AgentSequence[1].IsCurrent) +} + func TestAddAgentHandoffProgress(t *testing.T) { wf := newTestWorkflow() wf.Progress = []state.ProgressEntry{} diff --git a/cmd/sgai/mcp.go b/cmd/sgai/mcp.go index fa84ecd..26beddf 100644 --- a/cmd/sgai/mcp.go +++ b/cmd/sgai/mcp.go @@ -323,6 +323,51 @@ func parseAgentIdentityHeader(r *http.Request) string { return name } +type callerContext struct { + agentName string + modelID string +} + +func parseCallerContext(r *http.Request, coord *state.Coordinator) callerContext { + identity := r.Header.Get(agentIdentityHeader) + agentName := resolveCallerAgent(parseAgentIdentityHeader(r), coord) + modelID := parseModelIDFromAgentIdentity(identity) + if modelID == "" { + modelID = currentModelForCaller(coord, agentName) + } + return callerContext{agentName: agentName, modelID: modelID} +} + +func parseModelIDFromAgentIdentity(identity string) string { + if identity == "" { + return "" + } + agentName, rest, found := strings.Cut(identity, "|") + if !found || agentName == "" { + return "" + } + model, variant, _ := strings.Cut(rest, "|") + if model == "" { + return "" + } + modelSpec := model + if variant != "" { + modelSpec = model + " (" + variant + ")" + } + return formatModelID(agentName, modelSpec) +} + +func currentModelForCaller(coord *state.Coordinator, callerAgent string) string { + if coord == nil || callerAgent == "" { + return "" + } + currentModel := coord.State().CurrentModel + if extractAgentFromModelID(currentModel) != callerAgent { + return "" + } + return currentModel +} + func resolveCallerAgent(headerAgent string, coord *state.Coordinator) string { if headerAgent == "" { if currentAgent := coord.State().CurrentAgent; currentAgent != "" && currentAgent != "coordinator" { @@ -357,19 +402,22 @@ func buildMCPHTTPHandler(workingDir string, coord *state.Coordinator, dagAgents } func buildMCPServer(workingDir string, r *http.Request, coord *state.Coordinator, dagAgents []string, humanTools humanToolCallbacks) (*mcp.Server, error) { - agentName := resolveCallerAgent(parseAgentIdentityHeader(r), coord) - return buildMCPServerForAgent(workingDir, coord, dagAgents, humanTools, agentName) + return buildMCPServerForCaller(workingDir, coord, dagAgents, humanTools, parseCallerContext(r, coord)) } func buildMCPServerForAgent(workingDir string, coord *state.Coordinator, dagAgents []string, humanTools humanToolCallbacks, agentName string) (*mcp.Server, error) { + return buildMCPServerForCaller(workingDir, coord, dagAgents, humanTools, callerContext{agentName: agentName, modelID: currentModelForCaller(coord, agentName)}) +} + +func buildMCPServerForCaller(workingDir string, coord *state.Coordinator, dagAgents []string, humanTools humanToolCallbacks, caller callerContext) (*mcp.Server, error) { server := mcp.NewServer(newMCPImplementation("sgai"), nil) - mcpCtx := &mcpContext{workingDir: workingDir, coord: coord, dagAgents: dagAgents, agentName: agentName, humanTools: humanTools} + mcpCtx := &mcpContext{workingDir: workingDir, coord: coord, dagAgents: dagAgents, agentName: caller.agentName, modelID: caller.modelID, humanTools: humanTools} - if errRegister := registerCommonTools(server, mcpCtx, agentName); errRegister != nil { + if errRegister := registerCommonTools(server, mcpCtx, caller.agentName); errRegister != nil { return nil, errRegister } - if agentName == "coordinator" { + if caller.agentName == "coordinator" { if errRegister := registerCoordinatorTools(server, mcpCtx); errRegister != nil { return nil, errRegister } @@ -490,6 +538,7 @@ type mcpContext struct { coord *state.Coordinator dagAgents []string agentName string + modelID string humanTools humanToolCallbacks } @@ -520,7 +569,7 @@ func (c *mcpContext) updateWorkflowStateHandler(_ context.Context, _ *mcp.CallTo } func (c *mcpContext) sendMessageHandler(_ context.Context, _ *mcp.CallToolRequest, args sendMessageArgs) (*mcp.CallToolResult, emptyResult, error) { - result, err := sendMessage(c.workingDir, c.coord, c.dagAgents, c.agentName, args.ToAgent, args.Body) + result, err := sendMessageForCaller(c.workingDir, c.coord, c.dagAgents, callerContext{agentName: c.agentName, modelID: c.modelID}, args.ToAgent, args.Body) if err != nil { return nil, emptyResult{}, err } @@ -528,7 +577,7 @@ func (c *mcpContext) sendMessageHandler(_ context.Context, _ *mcp.CallToolReques } func (c *mcpContext) checkInboxHandler(_ context.Context, _ *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, emptyResult, error) { - result, err := checkInbox(c.coord, c.agentName) + result, err := checkInboxForCaller(c.coord, callerContext{agentName: c.agentName, modelID: c.modelID}) if err != nil { return nil, emptyResult{}, err } @@ -536,7 +585,7 @@ func (c *mcpContext) checkInboxHandler(_ context.Context, _ *mcp.CallToolRequest } func (c *mcpContext) checkOutboxHandler(_ context.Context, _ *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, emptyResult, error) { - result, err := checkOutbox(c.coord, c.agentName) + result, err := checkOutboxForCaller(c.coord, callerContext{agentName: c.agentName, modelID: c.modelID}) if err != nil { return nil, emptyResult{}, err } @@ -576,7 +625,16 @@ func (c *mcpContext) projectTodoReadHandler(_ context.Context, _ *mcp.CallToolRe } func (c *mcpContext) askUserQuestionHandler(ctx context.Context, _ *mcp.CallToolRequest, args askUserQuestionArgs) (*mcp.CallToolResult, emptyResult, error) { - result, err := c.askUserQuestionResponder()(ctx, c.coord, args) + responder := c.askUserQuestionResponder() + var ( + result string + err error + ) + if responder != nil { + result, err = responder(ctx, c.coord, args) + } else { + result, err = askUserQuestionForAgent(ctx, c.coord, c.agentName, args) + } if err != nil { return nil, emptyResult{}, err } @@ -584,7 +642,16 @@ func (c *mcpContext) askUserQuestionHandler(ctx context.Context, _ *mcp.CallTool } func (c *mcpContext) askUserWorkGateHandler(ctx context.Context, _ *mcp.CallToolRequest, args askUserWorkGateArgs) (*mcp.CallToolResult, emptyResult, error) { - result, err := c.askUserWorkGateResponder()(ctx, c.coord, args.Summary) + responder := c.askUserWorkGateResponder() + var ( + result string + err error + ) + if responder != nil { + result, err = responder(ctx, c.coord, args.Summary) + } else { + result, err = askUserWorkGateForAgent(ctx, c.coord, c.agentName, args.Summary) + } if err != nil { return nil, emptyResult{}, err } @@ -592,20 +659,18 @@ func (c *mcpContext) askUserWorkGateHandler(ctx context.Context, _ *mcp.CallTool } func (c *mcpContext) askUserQuestionResponder() askUserQuestionFunc { - if c.humanTools.question != nil { - return c.humanTools.question - } - return askUserQuestion + return c.humanTools.question } func (c *mcpContext) askUserWorkGateResponder() askUserWorkGateFunc { - if c.humanTools.workGate != nil { - return c.humanTools.workGate - } - return askUserWorkGate + return c.humanTools.workGate } func askUserQuestion(ctx context.Context, coord *state.Coordinator, args askUserQuestionArgs) (string, error) { + return askUserQuestionForAgent(ctx, coord, "", args) +} + +func askUserQuestionForAgent(ctx context.Context, coord *state.Coordinator, askingAgent string, args askUserQuestionArgs) (string, error) { if coord == nil { return "Error: workflow coordinator not available.", nil } @@ -615,10 +680,10 @@ func askUserQuestion(ctx context.Context, coord *state.Coordinator, args askUser return askUserQuestionAutoResponse(autoProceedAnswer)(ctx, coord, args) } - return askUserQuestionInteractive(ctx, coord, args) + return askUserQuestionInteractive(ctx, coord, askingAgent, args) } -func askUserQuestionInteractive(ctx context.Context, coord *state.Coordinator, args askUserQuestionArgs) (string, error) { +func askUserQuestionInteractive(ctx context.Context, coord *state.Coordinator, askingAgent string, args askUserQuestionArgs) (string, error) { if coord == nil { return "Error: workflow coordinator not available.", nil } @@ -628,7 +693,7 @@ func askUserQuestionInteractive(ctx context.Context, coord *state.Coordinator, a } question, humanMessage, questionSummary := buildQuestionRequest(args) - answer, errWait := waitForHumanResponse(ctx, coord, question, humanMessage, "ask_user_question") + answer, errWait := waitForHumanResponse(ctx, coord, question, humanMessage, normalizeHumanInputAgent(coord, askingAgent), "ask_user_question") if errWait != nil { return "", fmt.Errorf("waiting for human response: %w", errWait) } @@ -684,6 +749,10 @@ func buildQuestionRequest(args askUserQuestionArgs) (question *state.MultiChoice } func askUserWorkGate(ctx context.Context, coord *state.Coordinator, summary string) (string, error) { + return askUserWorkGateForAgent(ctx, coord, "", summary) +} + +func askUserWorkGateForAgent(ctx context.Context, coord *state.Coordinator, askingAgent, summary string) (string, error) { if validationErr := validateAskUserWorkGateSummary(summary); validationErr != "" { return validationErr, nil } @@ -697,10 +766,10 @@ func askUserWorkGate(ctx context.Context, coord *state.Coordinator, summary stri return askUserWorkGateAutoResponse(autoRecordQuestionsAnswer)(ctx, coord, summary) } - return askUserWorkGateInteractive(ctx, coord, summary) + return askUserWorkGateInteractive(ctx, coord, normalizeHumanInputAgent(coord, askingAgent), summary) } -func askUserWorkGateInteractive(ctx context.Context, coord *state.Coordinator, summary string) (string, error) { +func askUserWorkGateInteractive(ctx context.Context, coord *state.Coordinator, askingAgent, summary string) (string, error) { if coord == nil { return "Error: workflow coordinator not available.", nil } @@ -722,7 +791,7 @@ func askUserWorkGateInteractive(ctx context.Context, coord *state.Coordinator, s IsWorkGate: true, } - answer, errWait := waitForHumanResponse(ctx, coord, question, questionText, "ask_user_work_gate") + answer, errWait := waitForHumanResponse(ctx, coord, question, questionText, askingAgent, "ask_user_work_gate") if errWait != nil { return "", fmt.Errorf("waiting for human response: %w", errWait) } @@ -749,7 +818,21 @@ func validateAskUserWorkGateSummary(summary string) string { return "" } -func waitForHumanResponse(ctx context.Context, coord *state.Coordinator, question *state.MultiChoiceQuestion, humanMessage, toolName string) (string, error) { +func normalizeHumanInputAgent(coord *state.Coordinator, askingAgent string) string { + if strings.TrimSpace(askingAgent) != "" { + return askingAgent + } + if coord == nil { + return "" + } + currentAgent := coord.State().CurrentAgent + if hasParallelCurrentAgents(currentAgent) { + return "" + } + return currentAgent +} + +func waitForHumanResponse(ctx context.Context, coord *state.Coordinator, question *state.MultiChoiceQuestion, humanMessage, askingAgent, toolName string) (string, error) { if coord == nil { return "", errors.New("workflow coordinator not available") } @@ -757,7 +840,7 @@ func waitForHumanResponse(ctx context.Context, coord *state.Coordinator, questio ctxWait, cancel := context.WithTimeout(ctx, humanToolTimeout) defer cancel() - answer, errWait := coord.AskAndWait(ctxWait, question, humanMessage) + answer, errWait := coord.AskAndWait(ctxWait, question, humanMessage, askingAgent) if errWait == nil { return answer, nil } @@ -790,7 +873,7 @@ func promoteAfterWorkGate(coord *state.Coordinator) error { func selectHumanToolCallbacks(workingDir string, coord *state.Coordinator) humanToolCallbacks { if coord == nil { - return humanToolCallbacks{question: askUserQuestion, workGate: askUserWorkGate} + return humanToolCallbacks{question: nil, workGate: nil} } switch coord.State().InteractionMode { @@ -809,7 +892,7 @@ func selectHumanToolCallbacks(workingDir string, coord *state.Coordinator) human } } - return humanToolCallbacks{question: askUserQuestion, workGate: askUserWorkGate} + return humanToolCallbacks{question: nil, workGate: nil} } func readGoalMetadata(workingDir string) GoalMetadata { @@ -1151,7 +1234,6 @@ func findSnippetsByFuzzyMatch(langDir string, entries []os.DirEntry, query strin func updateWorkflowState(coord *state.Coordinator, callerAgent string, args updateWorkflowStateArgs) (string, error) { var response string - var shouldStartWatchdog bool if coord == nil { return "Error: workflow coordinator not available.", nil @@ -1182,12 +1264,7 @@ func updateWorkflowState(coord *state.Coordinator, callerAgent string, args upda currentState.Progress = []state.ProgressEntry{} } - if args.Status != "" { - currentState.Status = nextStatus - shouldStartWatchdog = nextStatus == state.StatusAgentDone - } - - currentState.Task = args.Task + updateAgentWorkflowState(currentState, callerAgent, nextStatus, args.Status != "", args.Task, true) if args.AddProgress != "" { entry := state.ProgressEntry{ @@ -1198,15 +1275,11 @@ func updateWorkflowState(coord *state.Coordinator, callerAgent string, args upda currentState.Progress = append(currentState.Progress, entry) } - if (currentState.Status == state.StatusComplete || currentState.Status == state.StatusAgentDone) && currentState.Task != "" { - currentState.Task = "" - } - if response == "" { response = "State updated successfully.\n" - response += fmt.Sprintf(" Status: %s\n", currentState.Status) - if currentState.Task != "" { - response += fmt.Sprintf(" Current task: %s\n", currentState.Task) + response += fmt.Sprintf(" Status: %s\n", visibleWorkflowStatus(currentState)) + if task := visibleWorkflowTask(currentState); task != "" { + response += fmt.Sprintf(" Current task: %s\n", task) } if args.AddProgress != "" { response += fmt.Sprintf(" Added progress note: %s\n", args.AddProgress) @@ -1223,14 +1296,14 @@ func updateWorkflowState(coord *state.Coordinator, callerAgent string, args upda return response, nil } - if shouldStartWatchdog { - coord.StartAgentDoneWatchdog(coord.AgentCancel()) - } - return response, nil } func sendMessage(workingDir string, coord *state.Coordinator, dagAgents []string, callerAgent, toAgent, body string) (string, error) { + return sendMessageForCaller(workingDir, coord, dagAgents, callerContext{agentName: callerAgent, modelID: currentModelForCaller(coord, callerAgent)}, toAgent, body) +} + +func sendMessageForCaller(workingDir string, coord *state.Coordinator, dagAgents []string, caller callerContext, toAgent, body string) (string, error) { if coord == nil { return "Error: Could not read state.json. Has the workflow been initialized?", nil } @@ -1251,9 +1324,9 @@ func sendMessage(workingDir string, coord *state.Coordinator, dagAgents []string currentState.Messages = []state.Message{} } - fromAgent = callerAgent - if currentState.CurrentModel != "" { - fromAgent = currentState.CurrentModel + fromAgent = caller.agentName + if caller.modelID != "" { + fromAgent = caller.modelID } createdAt := time.Now().UTC().Format(time.RFC3339) @@ -1267,7 +1340,7 @@ func sendMessage(workingDir string, coord *state.Coordinator, dagAgents []string } else { result = fmt.Sprintf("Sent %d messages successfully to %s.\nFrom: %s\nTo: %s\nBody: %s", len(recipients), toAgent, fromAgent, strings.Join(recipients, ", "), body) } - if callerAgent != "coordinator" { + if caller.agentName != "coordinator" { result += "\n\nIMPORTANT: To receive a response from the target agent, you MUST yield control by calling sgai_update_workflow_state({status: 'agent-done'}). The target agent cannot run until you yield." } }) @@ -1293,16 +1366,20 @@ func newStateMessage(id int, fromAgent, toAgent, body, createdAt string) state.M } func checkInbox(coord *state.Coordinator, callerAgent string) (string, error) { + return checkInboxForCaller(coord, callerContext{agentName: callerAgent, modelID: currentModelForCaller(coord, callerAgent)}) +} + +func checkInboxForCaller(coord *state.Coordinator, caller callerContext) (string, error) { if coord == nil { return "Error: Could not read state.json. Has the workflow been initialized?", nil } snapshot := coord.State() - currentModel := snapshot.CurrentModel + currentModel := caller.modelID var unreadMessages []state.Message for i := range snapshot.Messages { - if messageMatchesRecipient(&snapshot.Messages[i], callerAgent, currentModel) && !snapshot.Messages[i].Read { + if messageMatchesRecipient(&snapshot.Messages[i], caller.agentName, currentModel) && !snapshot.Messages[i].Read { unreadMessages = append(unreadMessages, snapshot.Messages[i]) } } @@ -1314,10 +1391,13 @@ func checkInbox(coord *state.Coordinator, callerAgent string) (string, error) { timestamp := time.Now().Format(time.RFC3339) errUpdate := coord.UpdateState(func(wf *state.Workflow) { for i := range wf.Messages { - if messageMatchesRecipient(&wf.Messages[i], callerAgent, currentModel) && !wf.Messages[i].Read { + if messageMatchesRecipient(&wf.Messages[i], caller.agentName, currentModel) && !wf.Messages[i].Read { wf.Messages[i].Read = true wf.Messages[i].ReadAt = timestamp - wf.Messages[i].ReadBy = callerAgent + wf.Messages[i].ReadBy = caller.agentName + if caller.modelID != "" { + wf.Messages[i].ReadBy = caller.modelID + } } } }) @@ -1335,19 +1415,22 @@ func checkInbox(coord *state.Coordinator, callerAgent string) (string, error) { return strings.TrimSpace(result.String()), nil } -//nolint:unparam // error is always nil by design - errors are handled by returning user-friendly messages func checkOutbox(coord *state.Coordinator, callerAgent string) (string, error) { + return checkOutboxForCaller(coord, callerContext{agentName: callerAgent, modelID: currentModelForCaller(coord, callerAgent)}) +} + +func checkOutboxForCaller(coord *state.Coordinator, caller callerContext) (string, error) { if coord == nil { return "Error: Could not read state.json. Has the workflow been initialized?", nil } snapshot := coord.State() - currentModel := snapshot.CurrentModel + currentModel := caller.modelID var unreadMessages []state.Message var readMessages []state.Message for i := range snapshot.Messages { - if messageMatchesSender(&snapshot.Messages[i], callerAgent, currentModel) { + if messageMatchesSender(&snapshot.Messages[i], caller.agentName, currentModel) { if snapshot.Messages[i].Read { readMessages = append(readMessages, snapshot.Messages[i]) } else { diff --git a/cmd/sgai/mcp_test.go b/cmd/sgai/mcp_test.go index 9f17223..ee6d7e0 100644 --- a/cmd/sgai/mcp_test.go +++ b/cmd/sgai/mcp_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "slices" "testing" + "testing/synctest" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/assert" @@ -125,6 +126,7 @@ func TestMCPHandlerErrorPaths(t *testing.T) { coord: nil, dagAgents: nil, agentName: "test", + modelID: "", humanTools: newTestHumanToolCallbacks(), } _, _, err := ctx.findSkillsHandler(context.Background(), nil, findSkillsArgs{Name: "exact-match"}) @@ -137,6 +139,7 @@ func TestMCPHandlerErrorPaths(t *testing.T) { coord: nil, dagAgents: nil, agentName: "test", + modelID: "", humanTools: newTestHumanToolCallbacks(), } result, _, err := ctx.findSnippetsHandler(context.Background(), nil, newTestFindSnippetsArgs("go", "")) @@ -198,6 +201,7 @@ func newTestMCPContextForAgent(t *testing.T, agentName string) (ctx *mcpContext, coord: coord, dagAgents: []string{"coordinator", agentName, "reviewer"}, agentName: agentName, + modelID: "", humanTools: newTestHumanToolCallbacks(), } return ctx, dir @@ -1833,7 +1837,7 @@ func TestRegisterCommonToolsInternal(t *testing.T) { coord, errCoord := state.NewCoordinatorWith(stateFile, newTestWorkflow()) require.NoError(t, errCoord) server := mcp.NewServer(newMCPImplementation("test"), nil) - mcpCtx := &mcpContext{workingDir: t.TempDir(), coord: coord, dagAgents: []string{"builder"}, agentName: "builder", humanTools: newTestHumanToolCallbacks()} + mcpCtx := &mcpContext{workingDir: t.TempDir(), coord: coord, dagAgents: []string{"builder"}, agentName: "builder", modelID: "", humanTools: newTestHumanToolCallbacks()} require.NoError(t, registerCommonTools(server, mcpCtx, "builder")) assert.NotNil(t, server) } @@ -1843,7 +1847,7 @@ func TestRegisterCoordinatorToolsInternal(t *testing.T) { coord, errCoord := state.NewCoordinatorWith(stateFile, newTestWorkflow()) require.NoError(t, errCoord) server := mcp.NewServer(newMCPImplementation("test"), nil) - mcpCtx := &mcpContext{workingDir: t.TempDir(), coord: coord, dagAgents: []string{"coordinator"}, agentName: "coordinator", humanTools: newTestHumanToolCallbacks()} + mcpCtx := &mcpContext{workingDir: t.TempDir(), coord: coord, dagAgents: []string{"coordinator"}, agentName: "coordinator", modelID: "", humanTools: newTestHumanToolCallbacks()} require.NoError(t, registerCoordinatorTools(server, mcpCtx)) assert.NotNil(t, server) } @@ -1855,7 +1859,7 @@ func TestRegisterCoordinatorToolsBrainstormingMode(t *testing.T) { })) require.NoError(t, errCoord) server := mcp.NewServer(newMCPImplementation("test"), nil) - mcpCtx := &mcpContext{workingDir: t.TempDir(), coord: coord, dagAgents: []string{"coordinator"}, agentName: "coordinator", humanTools: newTestHumanToolCallbacks()} + mcpCtx := &mcpContext{workingDir: t.TempDir(), coord: coord, dagAgents: []string{"coordinator"}, agentName: "coordinator", modelID: "", humanTools: newTestHumanToolCallbacks()} require.NoError(t, registerCoordinatorTools(server, mcpCtx)) assert.NotNil(t, server) } @@ -2070,6 +2074,92 @@ func TestAskUserWorkGateWithValidCoordinator(t *testing.T) { assert.Contains(t, result, "DEFINITION IS COMPLETE") } +func TestHumanInputToolsQueueParallelPrompts(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + stateFile := filepath.Join(t.TempDir(), "state.json") + coord, err := state.NewCoordinatorWith(stateFile, updated(newTestWorkflow(), func(workflow *state.Workflow) { + workflow.Status = state.StatusWorking + workflow.InteractionMode = state.ModeBrainstorming + workflow.CurrentAgent = "go-developer, react-developer" + })) + require.NoError(t, err) + + goCtx := &mcpContext{ + workingDir: t.TempDir(), + coord: coord, + dagAgents: []string{"go-developer", "react-developer"}, + agentName: "go-developer", + modelID: "", + humanTools: newTestHumanToolCallbacks(), + } + reactCtx := &mcpContext{ + workingDir: t.TempDir(), + coord: coord, + dagAgents: []string{"go-developer", "react-developer"}, + agentName: "react-developer", + modelID: "", + humanTools: newTestHumanToolCallbacks(), + } + + type toolCallResult struct { + result *mcp.CallToolResult + err error + } + + goToolCtx, cancelGoTool := context.WithCancel(context.Background()) + reactToolCtx, cancelReactTool := context.WithCancel(context.Background()) + defer cancelGoTool() + defer cancelReactTool() + + goDone := make(chan toolCallResult, 1) + go func() { + result, _, errCall := goCtx.askUserQuestionHandler(goToolCtx, nil, askUserQuestionArgs{ + Questions: []questionItem{newTestQuestionItemArgs("Go question?", []string{"A", "B"})}, + }) + goDone <- toolCallResult{result: result, err: errCall} + }() + + synctest.Wait() + require.True(t, coord.State().NeedsHumanInput()) + firstPromptToken := waitForSessionPromptToken(t, coord) + + reactDone := make(chan toolCallResult, 1) + go func() { + result, _, errCall := reactCtx.askUserWorkGateHandler(reactToolCtx, nil, askUserWorkGateArgs{Summary: "Parallel summary"}) + reactDone <- toolCallResult{result: result, err: errCall} + }() + + synctest.Wait() + select { + case result := <-reactDone: + cancelGoTool() + cancelReactTool() + synctest.Wait() + t.Fatalf("second human-input tool call returned before the first prompt was answered: %v", result.err) + default: + } + + require.True(t, coord.RespondIfCurrent(firstPromptToken, "A")) + synctest.Wait() + + firstResult := <-goDone + require.NoError(t, firstResult.err) + require.NotNil(t, firstResult.result) + + secondPromptToken := waitForSessionPromptToken(t, coord) + assert.NotEqual(t, firstPromptToken, secondPromptToken) + require.NotNil(t, coord.State().MultiChoiceQuestion) + assert.True(t, coord.State().MultiChoiceQuestion.IsWorkGate) + + require.True(t, coord.RespondIfCurrent(secondPromptToken, workGateApprovalText)) + synctest.Wait() + + secondResult := <-reactDone + require.NoError(t, secondResult.err) + require.NotNil(t, secondResult.result) + }) +} + func TestFindSnippetsNoLanguage(t *testing.T) { dir := t.TempDir() snippetsDir := filepath.Join(dir, ".sgai", "snippets") @@ -2237,13 +2327,11 @@ func TestUpdateWorkflowStatePendingTodosDoesNotMutateStatusOrStartWatchdog(t *te })} })) require.NoError(t, err) - coord.SetAgentCancel(func() {}) result, err := updateWorkflowState(coord, "builder", newTestUpdateWorkflowStateArgs("agent-done", "", "")) require.NoError(t, err) assert.Contains(t, result, "pending TODO") assert.Equal(t, state.StatusWorking, coord.State().Status) - assert.False(t, coord.IsShuttingDown()) reloaded, err := state.NewCoordinator(stateFile) require.NoError(t, err) @@ -2269,6 +2357,39 @@ func TestUpdateWorkflowStateUsesCallerAgentForPendingTodos(t *testing.T) { assert.Equal(t, state.StatusWorking, coord.State().Status) } +func TestUpdateWorkflowStateUsesTodosByAgentForCaller(t *testing.T) { + stateFile := filepath.Join(t.TempDir(), "state.json") + coord, err := state.NewCoordinatorWith(stateFile, updated(newTestWorkflow(), func(workflow *state.Workflow) { + workflow.Status = state.StatusWorking + workflow.CurrentAgent = "go-developer, react-developer" + workflow.TodosByAgent = map[string][]state.TodoItem{ + "go-developer": { + updated(newTestTodoItem(), func(todo *state.TodoItem) { + todo.Content = "unfinished go task" + todo.Status = "pending" + todo.Priority = "high" + }), + }, + "react-developer": { + updated(newTestTodoItem(), func(todo *state.TodoItem) { + todo.Content = "finished react task" + todo.Status = "completed" + todo.Priority = "high" + }), + }, + } + })) + require.NoError(t, err) + + goResult, errGo := updateWorkflowState(coord, "go-developer", newTestUpdateWorkflowStateArgs(state.StatusAgentDone, "", "")) + require.NoError(t, errGo) + assert.Contains(t, goResult, "pending TODO") + + reactResult, errReact := updateWorkflowState(coord, "react-developer", newTestUpdateWorkflowStateArgs(state.StatusAgentDone, "", "")) + require.NoError(t, errReact) + assert.NotContains(t, reactResult, "pending TODO") +} + func TestUpdateWorkflowStateCoordinatorUsesProjectTodos(t *testing.T) { stateFile := filepath.Join(t.TempDir(), "state.json") coord, err := state.NewCoordinatorWith(stateFile, updated(newTestWorkflow(), func(workflow *state.Workflow) { @@ -2313,6 +2434,45 @@ func TestUpdateWorkflowStateClearsTaskOnComplete(t *testing.T) { assert.Empty(t, wf.Task) } +func TestUpdateWorkflowStateAggregatesParallelBatchState(t *testing.T) { + stateFile := filepath.Join(t.TempDir(), "state.json") + coord, err := state.NewCoordinatorWith(stateFile, updated(newTestWorkflow(), func(workflow *state.Workflow) { + workflow.Status = state.StatusWorking + workflow.CurrentAgent = "go-developer, react-developer" + })) + require.NoError(t, err) + + goResult, errGo := updateWorkflowState(coord, "go-developer", newTestUpdateWorkflowStateArgs(state.StatusAgentDone, "finished go work", "")) + require.NoError(t, errGo) + assert.Contains(t, goResult, "Status: working") + + wf := coord.State() + assert.Equal(t, state.StatusWorking, wf.Status) + assert.Empty(t, wf.Task) + assert.Equal(t, state.StatusAgentDone, wf.AgentStates["go-developer"].Status) + assert.Empty(t, wf.AgentStates["go-developer"].Task) + + reactResult, errReact := updateWorkflowState(coord, "react-developer", newTestUpdateWorkflowStateArgs(state.StatusWorking, "reviewing frontend", "")) + require.NoError(t, errReact) + assert.Contains(t, reactResult, "Status: working") + + wf = coord.State() + assert.Equal(t, state.StatusWorking, wf.Status) + assert.Empty(t, wf.Task) + assert.Equal(t, state.StatusWorking, wf.AgentStates["react-developer"].Status) + assert.Equal(t, "reviewing frontend", wf.AgentStates["react-developer"].Task) + + reactDoneResult, errReactDone := updateWorkflowState(coord, "react-developer", newTestUpdateWorkflowStateArgs(state.StatusAgentDone, "frontend done", "")) + require.NoError(t, errReactDone) + assert.Contains(t, reactDoneResult, "Status: agent-done") + + wf = coord.State() + assert.Equal(t, state.StatusAgentDone, wf.Status) + assert.Empty(t, wf.Task) + assert.Equal(t, state.StatusAgentDone, wf.AgentStates["react-developer"].Status) + assert.Empty(t, wf.AgentStates["react-developer"].Task) +} + func TestSendMessageInvalidAgent(t *testing.T) { stateFile := filepath.Join(t.TempDir(), "state.json") coord, err := state.NewCoordinatorWith(stateFile, newTestWorkflow()) @@ -2331,6 +2491,38 @@ func TestSendMessageValidAgent(t *testing.T) { assert.Contains(t, result, "sent") } +func TestCheckInboxForCallerUsesExplicitModel(t *testing.T) { + stateFile := filepath.Join(t.TempDir(), "state.json") + coord, err := state.NewCoordinatorWith(stateFile, updated(newTestWorkflow(), func(workflow *state.Workflow) { + workflow.CurrentModel = "reviewer:model-b" + workflow.Messages = []state.Message{ + updated(newTestMessage(), func(message *state.Message) { + message.ID = 1 + message.ToAgent = "builder:model-a" + message.Body = "check the patch" + }), + } + })) + require.NoError(t, err) + + result, errCheck := checkInboxForCaller(coord, callerContext{agentName: "builder", modelID: "builder:model-a"}) + require.NoError(t, errCheck) + assert.Contains(t, result, "check the patch") + assert.True(t, coord.State().Messages[0].Read) +} + +func TestSendMessageForCallerUsesExplicitModel(t *testing.T) { + stateFile := filepath.Join(t.TempDir(), "state.json") + coord, err := state.NewCoordinatorWith(stateFile, updated(newTestWorkflow(), func(workflow *state.Workflow) { + workflow.CurrentModel = "reviewer:model-b" + })) + require.NoError(t, err) + + _, errSend := sendMessageForCaller("", coord, []string{"coordinator", "builder"}, callerContext{agentName: "builder", modelID: "builder:model-a"}, "coordinator", "hello from model a") + require.NoError(t, errSend) + assert.Equal(t, "builder:model-a", coord.State().Messages[0].FromAgent) +} + func TestListSnippetLanguagesNoDir(t *testing.T) { result, err := listSnippetLanguages(filepath.Join(t.TempDir(), "nonexistent")) require.NoError(t, err) diff --git a/cmd/sgai/serve.go b/cmd/sgai/serve.go index 7dffe2e..9a9f790 100644 --- a/cmd/sgai/serve.go +++ b/cmd/sgai/serve.go @@ -322,7 +322,7 @@ func workspaceListSummaryFromState(workspacePath string, wfState *state.Workflow if currentAgent == "" { currentAgent = "Unknown" } - status := wfState.Status + status := visibleWorkflowStatus(wfState) if status == "" { status = "-" } @@ -335,7 +335,7 @@ func workspaceListSummaryFromState(workspacePath string, wfState *state.Workflow InteractiveAuto: interactiveAuto, CurrentAgent: currentAgent, CurrentModel: resolveCurrentModel(workspacePath, wfState), - Task: wfState.Task, + Task: visibleWorkflowTask(wfState), LatestProgress: getLatestProgress(wfState.Progress), HumanMessage: wfState.HumanMessage, } @@ -629,7 +629,7 @@ func (s *Server) stopSession(workspacePath string) { } } -func (s *Server) finishSessionRun(workspacePath string, sess *session, coord *state.Coordinator) { +func (s *Server) finishSessionRun(workspacePath string, sess *session, _ *state.Coordinator) { if sess == nil { return } @@ -644,9 +644,6 @@ func (s *Server) finishSessionRun(workspacePath string, sess *session, coord *st if closeFn != nil { sess.mcpCloseOnce.Do(closeFn) } - if coord != nil { - coord.Stop() - } s.clearEverStartedOnCompletion(workspacePath) if !skipExitNotification { s.notifyWorkspaceListChange(workspacePath) @@ -654,13 +651,14 @@ func (s *Server) finishSessionRun(workspacePath string, sess *session, coord *st } func badgeStatus(wfState *state.Workflow, running bool) (class, text string) { + status := visibleWorkflowStatus(wfState) if wfState.NeedsHumanInput() { return "badge-needs-input", "Needs Input" } - if running || wfState.Status == state.StatusWorking || wfState.Status == state.StatusAgentDone { + if running || status == state.StatusWorking || status == state.StatusAgentDone { return "badge-running", "Running" } - if !running && wfState.Status == state.StatusComplete { + if !running && status == state.StatusComplete { return "badge-complete", "Complete" } return "badge-stopped", "Stopped" @@ -863,6 +861,7 @@ func getWorkflowSVG(dir, currentAgent string) string { if retrospectiveEnabled(metadata.Retrospective) { d.injectRetrospectiveEdge() } + addCurrentAgentsToGraph(d, currentAgent) dotContent := d.toDOT() @@ -874,6 +873,12 @@ func getWorkflowSVG(dir, currentAgent string) string { return renderDotToSVG(dotContent) } +func addCurrentAgentsToGraph(d *dag, currentAgent string) { + for _, agent := range splitCurrentAgents(currentAgent) { + d.ensureNode(agent) + } +} + func (s *Server) getWorkflowSVGCached(dir, currentAgent string) string { cacheKey := dir + "|" + currentAgent if cached, ok := s.svgCache.get(cacheKey); ok { @@ -1104,14 +1109,16 @@ func extractSubject(body string) string { } func injectCurrentAgentStyle(dot, currentAgent string) string { - agentLine := fmt.Sprintf(" %q", currentAgent) - styledLine := fmt.Sprintf(" %q [style=filled, fillcolor=\"#10b981\", fontcolor=white]", currentAgent) - - if !strings.Contains(dot, agentLine) { - return dot + result := dot + for _, agent := range splitCurrentAgents(currentAgent) { + agentLine := fmt.Sprintf(" %q", agent) + styledLine := fmt.Sprintf(" %q [style=filled, fillcolor=\"#10b981\", fontcolor=white]", agent) + if !strings.Contains(result, agentLine) { + continue + } + result = strings.Replace(result, agentLine, styledLine, 1) } - - return strings.Replace(dot, agentLine, styledLine, 1) + return result } func injectLightTheme(dot string) string { diff --git a/cmd/sgai/serve_api.go b/cmd/sgai/serve_api.go index b728e1b..9532653 100644 --- a/cmd/sgai/serve_api.go +++ b/cmd/sgai/serve_api.go @@ -309,7 +309,7 @@ type apiWorkspaceFullState struct { Events []apiEventEntry `json:"events"` Messages []apiMessageEntry `json:"messages"` ProjectTodos []apiTodoEntry `json:"projectTodos"` - AgentTodos []apiTodoEntry `json:"agentTodos"` + AgentTodoSections []apiAgentTodoSection `json:"agentTodoSections"` Forks []apiForkEntry `json:"forks,omitempty"` Log []apiLogEntry `json:"log"` PendingQuestion *apiPendingQuestionResponse `json:"pendingQuestion,omitempty"` @@ -493,13 +493,19 @@ func (s *Server) buildWorkspaceListEntry(ws workspaceInfo, groups []workspaceGro interactiveAuto := wfState.InteractionMode == state.ModeSelfDrive || wfState.InteractionMode == state.ModeContinuous badgeClass, badgeText := badgeStatus(&wfState, ws.Running) needsInput := wfState.NeedsHumanInput() + humanMessage := wfState.HumanMessage + if coord := s.sessionCoordinator(ws.Directory); coord != nil { + humanInput := currentHumanInputSnapshot(coord) + needsInput = humanInput.needsInput() + humanMessage = humanInput.humanMessage + } currentAgent := wfState.CurrentAgent if currentAgent == "" { currentAgent = "Unknown" } - status := wfState.Status + status := visibleWorkflowStatus(&wfState) if status == "" { status = "-" } @@ -531,12 +537,12 @@ func (s *Server) buildWorkspaceListEntry(ws workspaceInfo, groups []workspaceGro ContinuousMode: readContinuousModePrompt(ws.Directory) != "", CurrentAgent: currentAgent, CurrentModel: resolveCurrentModel(ws.Directory, &wfState), - Task: wfState.Task, + Task: visibleWorkflowTask(&wfState), Title: titleState.Title, ComputedTitle: workspaceComputedTitle(ws, groups, titleState), TotalExecTime: calculateTotalExecutionTime(wfState.AgentSequence, ws.Running, getLastActivityTime(wfState.Progress)), LatestProgress: getLatestProgress(wfState.Progress), - HumanMessage: wfState.HumanMessage, + HumanMessage: humanMessage, Forks: nil, RepositoryAction: repositoryAction.api(ws.DirName), } @@ -555,13 +561,14 @@ func (s *Server) buildWorkspaceFullState(ws workspaceInfo, groups []workspaceGro interactiveAuto := wfState.InteractionMode == state.ModeSelfDrive || wfState.InteractionMode == state.ModeContinuous badgeClass, badgeText := badgeStatus(&wfState, ws.Running) needsInput := wfState.NeedsHumanInput() + humanMessage := wfState.HumanMessage currentAgent := wfState.CurrentAgent if currentAgent == "" { currentAgent = "Unknown" } - status := wfState.Status + status := visibleWorkflowStatus(&wfState) if status == "" { status = "-" } @@ -598,27 +605,11 @@ func (s *Server) buildWorkspaceFullState(ws workspaceInfo, groups []workspaceGro } var pendingQuestion *apiPendingQuestionResponse - if wfState.NeedsHumanInput() { - coord := s.sessionCoordinator(ws.Directory) - agentName := currentAgent - var questions []apiQuestionItem - if wfState.MultiChoiceQuestion != nil { - questions = make([]apiQuestionItem, 0, len(wfState.MultiChoiceQuestion.Questions)) - for _, q := range wfState.MultiChoiceQuestion.Questions { - questions = append(questions, apiQuestionItem{ - Question: q.Question, - Choices: q.Choices, - MultiSelect: q.MultiSelect, - }) - } - } - pendingQuestion = &apiPendingQuestionResponse{ - PromptToken: promptTokenForState(coord, &wfState), - Type: questionType(&wfState), - AgentName: agentName, - Message: wfState.HumanMessage, - Questions: questions, - } + if coord := s.sessionCoordinator(ws.Directory); coord != nil { + humanInput := currentHumanInputSnapshot(coord) + needsInput = humanInput.needsInput() + humanMessage = humanInput.humanMessage + pendingQuestion = humanInput.pendingQuestion(currentAgent) } actionState := loadActionsForAPI(ws.Directory) @@ -645,7 +636,7 @@ func (s *Server) buildWorkspaceFullState(ws workspaceInfo, groups []workspaceGro ContinuousMode: readContinuousModePrompt(ws.Directory) != "", CurrentAgent: currentAgent, CurrentModel: resolveCurrentModel(ws.Directory, &wfState), - Task: wfState.Task, + Task: visibleWorkflowTask(&wfState), GoalContent: goalContent, Title: titleState.Title, ComputedTitle: workspaceComputedTitle(ws, groups, titleState), @@ -656,7 +647,7 @@ func (s *Server) buildWorkspaceFullState(ws workspaceInfo, groups []workspaceGro SVGHash: s.getWorkflowSVGHashCached(ws.Directory, currentAgent), TotalExecTime: totalExecTime, LatestProgress: getLatestProgress(wfState.Progress), - HumanMessage: wfState.HumanMessage, + HumanMessage: humanMessage, AgentSequence: agentSeq, Cost: wfState.Cost, ModelStatuses: modelStatuses, @@ -664,7 +655,7 @@ func (s *Server) buildWorkspaceFullState(ws workspaceInfo, groups []workspaceGro Events: events, Messages: messages, ProjectTodos: convertTodosForAPI(wfState.ProjectTodos), - AgentTodos: convertTodosForAPI(wfState.Todos), + AgentTodoSections: convertAgentTodoSectionsForAPI(wfState.CurrentAgent, wfState.TodosByAgent), Log: logLines, PendingQuestion: pendingQuestion, Actions: actionState.Actions, @@ -1324,6 +1315,11 @@ type apiTodoEntry struct { Priority string `json:"priority"` } +type apiAgentTodoSection struct { + Agent string `json:"agent"` + Todos []apiTodoEntry `json:"todos"` +} + func convertTodosForAPI(todos []state.TodoItem) []apiTodoEntry { result := make([]apiTodoEntry, 0, len(todos)) for _, t := range todos { @@ -1337,6 +1333,18 @@ func convertTodosForAPI(todos []state.TodoItem) []apiTodoEntry { return result } +func convertAgentTodoSectionsForAPI(currentAgent string, todosByAgent map[string][]state.TodoItem) []apiAgentTodoSection { + currentAgents := splitCurrentAgents(currentAgent) + result := make([]apiAgentTodoSection, 0, len(currentAgents)) + for _, agent := range currentAgents { + result = append(result, apiAgentTodoSection{ + Agent: agent, + Todos: convertTodosForAPI(todosByAgent[agent]), + }) + } + return result +} + type apiLogEntry struct { Prefix string `json:"prefix"` Text string `json:"text"` @@ -1375,26 +1383,79 @@ type apiForkEntry struct { ComputedTitle string `json:"computedTitle,omitempty"` } -func promptTokenForState(coord *state.Coordinator, wfState *state.Workflow) string { - if coord == nil || !wfState.NeedsHumanInput() { - return "" +type apiHumanInputSnapshot struct { + promptToken string + humanMessage string + askingAgent string + question *state.MultiChoiceQuestion +} + +func currentHumanInputSnapshot(coord *state.Coordinator) apiHumanInputSnapshot { + if coord == nil { + var humanInput apiHumanInputSnapshot + return humanInput + } + promptToken, humanMessage, askingAgent, question := coord.CurrentHumanInput() + return apiHumanInputSnapshot{ + promptToken: promptToken, + humanMessage: humanMessage, + askingAgent: askingAgent, + question: question, + } +} + +func (h apiHumanInputSnapshot) needsInput() bool { + return h.promptToken != "" || h.humanMessage != "" || h.question != nil +} + +func (h apiHumanInputSnapshot) pendingQuestion(currentAgent string) *apiPendingQuestionResponse { + if !h.needsInput() { + return nil + } + return &apiPendingQuestionResponse{ + PromptToken: h.promptToken, + Type: questionType(h.question, h.humanMessage), + AgentName: pendingQuestionAgentName(h.askingAgent, currentAgent), + Message: h.humanMessage, + Questions: apiQuestionItems(h.question), + } +} + +func pendingQuestionAgentName(askingAgent, currentAgent string) string { + if askingAgent != "" { + return askingAgent } - return coord.CurrentPromptToken() + return currentAgent } -func questionType(wfState *state.Workflow) string { - if wfState.MultiChoiceQuestion != nil { - if wfState.MultiChoiceQuestion.IsWorkGate { +func questionType(question *state.MultiChoiceQuestion, humanMessage string) string { + if question != nil { + if question.IsWorkGate { return "work-gate" } return "multi-choice" } - if wfState.HumanMessage != "" { + if humanMessage != "" { return "free-text" } return "" } +func apiQuestionItems(question *state.MultiChoiceQuestion) []apiQuestionItem { + if question == nil { + return nil + } + items := make([]apiQuestionItem, 0, len(question.Questions)) + for _, q := range question.Questions { + items = append(items, apiQuestionItem{ + Question: q.Question, + Choices: q.Choices, + MultiSelect: q.MultiSelect, + }) + } + return items +} + type apiQuestionItem struct { Question string `json:"question"` Choices []string `json:"choices"` @@ -1410,7 +1471,7 @@ type apiPendingQuestionResponse struct { } type apiRespondRequest struct { - PromptToken string `json:"promptToken,omitempty"` + PromptToken string `json:"promptToken"` Answer string `json:"answer"` SelectedChoices []string `json:"selectedChoices"` } @@ -1460,7 +1521,7 @@ func (s *Server) handleRespondViaCoordinator(w http.ResponseWriter, workspacePat switch { case errors.Is(errRespond, errNoPendingQuestion), errors.Is(errRespond, errQuestionNotAvailable): statusCode = http.StatusConflict - case errors.Is(errRespond, errResponseCannotBeEmpty): + case errors.Is(errRespond, errResponseCannotBeEmpty), errors.Is(errRespond, errPromptTokenRequired): statusCode = http.StatusBadRequest } http.Error(w, errRespond.Error(), statusCode) @@ -2145,6 +2206,9 @@ func (s *Server) coordinatorModelFromWorkspace(workspace string) string { } func resolveCurrentModel(workspacePath string, wfState *state.Workflow) string { + if hasParallelCurrentAgents(wfState.CurrentAgent) { + return "" + } if wfState.CurrentModel != "" { return wfState.CurrentModel } diff --git a/cmd/sgai/serve_api_test.go b/cmd/sgai/serve_api_test.go index 31ff4ba..1b8a2ec 100644 --- a/cmd/sgai/serve_api_test.go +++ b/cmd/sgai/serve_api_test.go @@ -153,10 +153,8 @@ func attachRunningSessionCoordinator(t *testing.T, srv *Server, wsDir string, wf func attachSessionCoordinatorWithRunning(t *testing.T, srv *Server, wsDir string, wf *state.Workflow, running bool) { t.Helper() statePath := filepath.Join(wsDir, ".sgai", "state.json") - coord := state.NewCoordinatorEmpty(statePath) - require.NoError(t, coord.UpdateState(func(current *state.Workflow) { - *current = *wf - })) + coord, errCoord := state.NewCoordinatorWith(statePath, *wf) + require.NoError(t, errCoord) srv.mu.Lock() srv.sessions[wsDir] = newTestServeSession(coord, running) srv.mu.Unlock() @@ -214,7 +212,7 @@ func startWaitingSessionQuestion(t *testing.T, srv *Server, wsDir string, questi ctx, cancel := context.WithCancel(context.Background()) errCh := make(chan error, 1) go func() { - _, err := coord.AskAndWait(ctx, question, humanMessage) + _, err := coord.AskAndWait(ctx, question, humanMessage, "coordinator") errCh <- err }() <-ready @@ -446,7 +444,7 @@ func TestQuestionType(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := questionType(&tt.wfState) + result := questionType(tt.wfState.MultiChoiceQuestion, tt.wfState.HumanMessage) assert.Equal(t, tt.expected, result) }) } @@ -518,58 +516,44 @@ func TestConvertEventsForAPIBoost(t *testing.T) { }) } -func TestResolveCurrentModelVariants(t *testing.T) { - t.Run("fromState", func(t *testing.T) { - wfState := workflowWith(func(workflow *state.Workflow) { workflow.CurrentModel = "claude-opus-4" }) - result := resolveCurrentModel("/some/path", &wfState) - assert.Equal(t, "claude-opus-4", result) - }) - - t.Run("noAgent", func(t *testing.T) { - wfState := newTestWorkflow() - result := resolveCurrentModel("/some/path", &wfState) - assert.Empty(t, result) - }) - - t.Run("fromGoalFile", func(t *testing.T) { - dir := t.TempDir() - goalPath := filepath.Join(dir, "GOAL.md") - content := "---\nmodels:\n coordinator: claude-opus-4\n---\n# Goal" - require.NoError(t, os.WriteFile(goalPath, []byte(content), 0o644)) - - wfState := workflowWith(func(workflow *state.Workflow) { workflow.CurrentAgent = "coordinator" }) - result := resolveCurrentModel(dir, &wfState) - assert.Equal(t, "claude-opus-4", result) - }) - - t.Run("agentNotInGoal", func(t *testing.T) { - dir := t.TempDir() - goalPath := filepath.Join(dir, "GOAL.md") - content := "---\nmodels:\n coordinator: claude-opus-4\n---\n# Goal" - require.NoError(t, os.WriteFile(goalPath, []byte(content), 0o644)) - - wfState := workflowWith(func(workflow *state.Workflow) { workflow.CurrentAgent = "developer" }) - result := resolveCurrentModel(dir, &wfState) +func TestConvertAgentTodoSectionsForAPI(t *testing.T) { + t.Run("noActiveAgents", func(t *testing.T) { + result := convertAgentTodoSectionsForAPI("", map[string][]state.TodoItem{ + "go-developer": { + todoItemWith(func(todo *state.TodoItem) { + todo.Content = "ignored" + todo.Status = "pending" + todo.Priority = "high" + }), + }, + }) assert.Empty(t, result) }) - t.Run("withExplicitModel", func(t *testing.T) { - wf := workflowWith(func(workflow *state.Workflow) { workflow.CurrentModel = "opus-4" }) - result := resolveCurrentModel("/tmp", &wf) - assert.Equal(t, "opus-4", result) - }) - - t.Run("noAgentReturnsEmpty", func(t *testing.T) { - wf := newTestWorkflow() - result := resolveCurrentModel("/tmp", &wf) - assert.Empty(t, result) - }) + t.Run("singleActiveAgent", func(t *testing.T) { + result := convertAgentTodoSectionsForAPI("go-developer", map[string][]state.TodoItem{ + "go-developer": { + todoItemWith(func(todo *state.TodoItem) { + todo.Content = "grouped task" + todo.Status = "in_progress" + todo.Priority = "high" + }), + }, + "react-developer": { + todoItemWith(func(todo *state.TodoItem) { + todo.Content = "hidden task" + todo.Status = "pending" + todo.Priority = "medium" + }), + }, + }) - t.Run("noModelReturnsEmpty", func(t *testing.T) { - dir := t.TempDir() - wf := newTestWorkflow() - result := resolveCurrentModel(dir, &wf) - assert.Empty(t, result) + if assert.Len(t, result, 1) { + assert.Equal(t, "go-developer", result[0].Agent) + if assert.Len(t, result[0].Todos, 1) { + assert.Equal(t, "grouped task", result[0].Todos[0].Content) + } + } }) } @@ -1065,42 +1049,140 @@ func TestResolveRootForDeleteForkForkReturnsPath(t *testing.T) { } func TestQuestionTypeFreeformMessage(t *testing.T) { - wf := workflowWith(func(workflow *state.Workflow) { - workflow.Status = state.StatusWorking - workflow.HumanMessage = "What do you think?" - }) - assert.Equal(t, "free-text", questionType(&wf)) + assert.Equal(t, "free-text", questionType(nil, "What do you think?")) } func TestQuestionTypeMultiChoiceQuestions(t *testing.T) { - wf := workflowWith(func(workflow *state.Workflow) { - workflow.Status = state.StatusWorking - workflow.MultiChoiceQuestion = multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { + question := multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { + question.Questions = []state.QuestionItem{ + questionItemWith(func(item *state.QuestionItem) { + item.Question = "Pick one" + item.Choices = []string{"A", "B"} + }), + } + }) + assert.Equal(t, "multi-choice", questionType(question, "")) +} + +func TestQuestionTypeWorkGateFlag(t *testing.T) { + question := multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { + question.IsWorkGate = true + question.Questions = []state.QuestionItem{ + questionItemWith(func(item *state.QuestionItem) { + item.Question = "Approve?" + item.Choices = []string{"Yes", "No"} + }), + } + }) + assert.Equal(t, "work-gate", questionType(question, "")) +} + +func TestHandleAPIWorkspaceStateUsesCurrentPromptSnapshotAfterPromotion(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + srv, rootDir := setupTestServer(t) + wsDir := setupTestWorkspace(t, srv, rootDir, "pq-promotion-snapshot") + require.NoError(t, os.WriteFile(filepath.Join(wsDir, "GOAL.md"), []byte("---\n---\n# Goal"), 0o644)) + + coord := state.NewCoordinatorEmpty(filepath.Join(wsDir, ".sgai", "state.json")) + srv.mu.Lock() + srv.sessions[wsDir] = newTestServeSession(coord, true) + srv.mu.Unlock() + + firstQuestion := multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { question.Questions = []state.QuestionItem{ questionItemWith(func(item *state.QuestionItem) { - item.Question = "Pick one" - item.Choices = []string{"A", "B"} + item.Question = "First question" + item.Choices = []string{"A"} }), } }) - }) - assert.Equal(t, "multi-choice", questionType(&wf)) -} + firstCtx, cancelFirst := context.WithCancel(context.Background()) + defer cancelFirst() + firstAnswerCh := make(chan string, 1) + errFirstCh := make(chan error, 1) + go func() { + answer, errWait := coord.AskAndWait(firstCtx, firstQuestion, "first message", "go-developer") + if errWait != nil { + errFirstCh <- errWait + return + } + firstAnswerCh <- answer + }() -func TestQuestionTypeWorkGateFlag(t *testing.T) { - wf := workflowWith(func(workflow *state.Workflow) { - workflow.Status = state.StatusWorking - workflow.MultiChoiceQuestion = multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { - question.IsWorkGate = true + synctest.Wait() + firstToken := coord.CurrentPromptToken() + require.NotEmpty(t, firstToken) + + secondQuestion := multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { question.Questions = []state.QuestionItem{ questionItemWith(func(item *state.QuestionItem) { - item.Question = "Approve?" - item.Choices = []string{"Yes", "No"} + item.Question = "Second question" + item.Choices = []string{"B"} }), } }) + secondCtx, cancelSecond := context.WithCancel(context.Background()) + defer cancelSecond() + secondAnswerCh := make(chan string, 1) + errSecondCh := make(chan error, 1) + go func() { + answer, errWait := coord.AskAndWait(secondCtx, secondQuestion, "second message", "react-developer") + if errWait != nil { + errSecondCh <- errWait + return + } + secondAnswerCh <- answer + }() + + synctest.Wait() + staleState := coord.State() + require.Equal(t, "first message", staleState.HumanMessage) + require.Equal(t, "go-developer", staleState.HumanInputAgent) + require.NotNil(t, staleState.MultiChoiceQuestion) + require.Len(t, staleState.MultiChoiceQuestion.Questions, 1) + require.Equal(t, "First question", staleState.MultiChoiceQuestion.Questions[0].Question) + + require.True(t, coord.RespondIfCurrent(firstToken, "first answer")) + synctest.Wait() + select { + case answer := <-firstAnswerCh: + require.Equal(t, "first answer", answer) + case errWait := <-errFirstCh: + require.NoError(t, errWait) + default: + t.Fatal("first prompt did not complete") + } + + secondToken := coord.CurrentPromptToken() + require.NotEmpty(t, secondToken) + require.NotEqual(t, firstToken, secondToken) + + w := serveHTTP(srv, "GET", "/api/v1/workspaces/pq-promotion-snapshot/state", "") + assert.Equal(t, http.StatusOK, w.Code) + + var workspace apiWorkspaceFullState + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &workspace)) + require.NotNil(t, workspace.PendingQuestion) + assert.Equal(t, secondToken, workspace.PendingQuestion.PromptToken) + assert.Equal(t, "second message", workspace.PendingQuestion.Message) + assert.Equal(t, "react-developer", workspace.PendingQuestion.AgentName) + assert.Equal(t, "multi-choice", workspace.PendingQuestion.Type) + require.Len(t, workspace.PendingQuestion.Questions, 1) + assert.Equal(t, "Second question", workspace.PendingQuestion.Questions[0].Question) + assert.NotEqual(t, staleState.HumanMessage, workspace.PendingQuestion.Message) + assert.NotEqual(t, staleState.MultiChoiceQuestion.Questions[0].Question, workspace.PendingQuestion.Questions[0].Question) + + cancelSecond() + synctest.Wait() + select { + case errWait := <-errSecondCh: + require.ErrorIs(t, errWait, context.Canceled) + case answer := <-secondAnswerCh: + t.Fatalf("expected second prompt cancellation, got %q", answer) + default: + t.Fatal("second prompt did not complete") + } }) - assert.Equal(t, "work-gate", questionType(&wf)) } func TestLoadWorkspaceStateInvalidJSONReturnsWorkingFallback(t *testing.T) { @@ -2879,20 +2961,30 @@ func TestBuildWorkspaceFullStateWithMessages(t *testing.T) { assert.Len(t, result.Messages, 2) } -func TestBuildWorkspaceFullStateWithTodos(t *testing.T) { +func TestBuildWorkspaceFullStateWithAgentTodoSections(t *testing.T) { srv, rootDir := setupTestServer(t) wsDir := setupTestWorkspace(t, srv, rootDir, "todos-ws") require.NoError(t, os.WriteFile(filepath.Join(wsDir, "GOAL.md"), []byte("---\n---\n# Goal"), 0o644)) statePath := filepath.Join(wsDir, ".sgai", "state.json") _, errCoord := state.NewCoordinatorWith(statePath, workflowWith(func(workflow *state.Workflow) { workflow.Status = state.StatusComplete + workflow.CurrentAgent = "react-developer, go-developer" workflow.Todos = []state.TodoItem{ todoItemWith(func(todo *state.TodoItem) { - todo.Content = "task1" + todo.Content = "stale visible task" todo.Status = "completed" todo.Priority = "high" }), } + workflow.TodosByAgent = map[string][]state.TodoItem{ + "go-developer": { + todoItemWith(func(todo *state.TodoItem) { + todo.Content = "grouped task" + todo.Status = "in_progress" + todo.Priority = "high" + }), + }, + } workflow.ProjectTodos = []state.TodoItem{ todoItemWith(func(todo *state.TodoItem) { todo.Content = "proj-task" @@ -2909,8 +3001,28 @@ func TestBuildWorkspaceFullStateWithTodos(t *testing.T) { workspace.HasWorkspace = true }) result := srv.buildWorkspaceFullState(ws, nil) - assert.Len(t, result.AgentTodos, 1) - assert.Len(t, result.ProjectTodos, 1) + payload, errMarshal := json.Marshal(result) + require.NoError(t, errMarshal) + + var apiResult struct { + ProjectTodos []apiTodoEntry `json:"projectTodos"` + AgentTodoSections []struct { + Agent string `json:"agent"` + Todos []apiTodoEntry `json:"todos"` + } `json:"agentTodoSections"` + AgentTodos json.RawMessage `json:"agentTodos"` + } + require.NoError(t, json.Unmarshal(payload, &apiResult)) + assert.Len(t, apiResult.ProjectTodos, 1) + if assert.Len(t, apiResult.AgentTodoSections, 2) { + assert.Equal(t, "react-developer", apiResult.AgentTodoSections[0].Agent) + assert.Empty(t, apiResult.AgentTodoSections[0].Todos) + assert.Equal(t, "go-developer", apiResult.AgentTodoSections[1].Agent) + if assert.Len(t, apiResult.AgentTodoSections[1].Todos, 1) { + assert.Equal(t, "grouped task", apiResult.AgentTodoSections[1].Todos[0].Content) + } + } + assert.Empty(t, apiResult.AgentTodos) } func TestHandleAPIWorkspaceStateFullIntegration(t *testing.T) { @@ -3023,6 +3135,88 @@ func TestHandleAPIWorkspaceStatePendingQuestionUsesPromptToken(t *testing.T) { require.ErrorIs(t, <-errCh, context.Canceled) } +func TestHandleAPIWorkspaceStatePendingQuestionUsesAskingAgentDuringParallelBatch(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + srv, rootDir := setupTestServer(t) + wsDir := setupTestWorkspace(t, srv, rootDir, "pq-parallel-agent") + require.NoError(t, os.WriteFile(filepath.Join(wsDir, "GOAL.md"), []byte("---\n---\n# Goal"), 0o644)) + + statePath := filepath.Join(wsDir, ".sgai", "state.json") + coord, errCoord := state.NewCoordinatorWith(statePath, workflowWith(func(workflow *state.Workflow) { + workflow.Status = state.StatusWorking + workflow.InteractionMode = state.ModeBrainstorming + workflow.CurrentAgent = "go-developer, react-developer" + })) + require.NoError(t, errCoord) + + srv.mu.Lock() + srv.sessions[wsDir] = newTestServeSession(coord, true) + srv.mu.Unlock() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error, 1) + go func() { + mcpCtx := &mcpContext{ + workingDir: wsDir, + coord: coord, + dagAgents: []string{"go-developer", "react-developer"}, + agentName: "go-developer", + modelID: "", + humanTools: newTestHumanToolCallbacks(), + } + _, _, errCall := mcpCtx.askUserQuestionHandler(ctx, nil, askUserQuestionArgs{ + Questions: []questionItem{newTestQuestionItemArgs("Which change?", []string{"A", "B"})}, + }) + errCh <- errCall + }() + + synctest.Wait() + require.True(t, coord.State().NeedsHumanInput()) + promptToken := waitForSessionPromptToken(t, coord) + + w := serveHTTP(srv, "GET", "/api/v1/workspaces/pq-parallel-agent/state", "") + assert.Equal(t, http.StatusOK, w.Code) + + var ws apiWorkspaceFullState + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &ws)) + require.NotNil(t, ws.PendingQuestion) + assert.Equal(t, promptToken, ws.PendingQuestion.PromptToken) + assert.Equal(t, "go-developer, react-developer", ws.CurrentAgent) + assert.Equal(t, "go-developer", ws.PendingQuestion.AgentName) + + require.True(t, coord.RespondIfCurrent(promptToken, "A")) + synctest.Wait() + require.NoError(t, <-errCh) + }) +} + +func TestRespondViaCoordinatorRejectsMissingPromptToken(t *testing.T) { + srv, rootDir := setupTestServer(t) + wsDir := setupTestWorkspace(t, srv, rootDir, "respond-missing-token") + require.NoError(t, os.WriteFile(filepath.Join(wsDir, "GOAL.md"), []byte("---\n---\n# Goal"), 0o644)) + + coord, errCh, cancel := startWaitingSessionQuestion(t, srv, wsDir, multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { + question.Questions = []state.QuestionItem{ + questionItemWith(func(item *state.QuestionItem) { + item.Question = "Pick one" + item.Choices = []string{"A", "B"} + }), + } + }), "Pick an option") + defer cancel() + promptToken := waitForSessionPromptToken(t, coord) + + body := `{"answer":"go with A","selectedChoices":["A"]}` + w := serveHTTP(srv, "POST", "/api/v1/workspaces/respond-missing-token/respond", body) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), errPromptTokenRequired.Error()) + assert.True(t, coord.State().NeedsHumanInput()) + + require.True(t, coord.RespondIfCurrent(promptToken, "go with A")) + require.NoError(t, <-errCh) +} + func TestRespondViaCoordinatorFullPath(t *testing.T) { srv, rootDir := setupTestServer(t) wsDir := setupTestWorkspace(t, srv, rootDir, "respond-full") @@ -3037,8 +3231,14 @@ func TestRespondViaCoordinatorFullPath(t *testing.T) { } }), "Pick an option") defer cancel() - body := `{"answer":"go with A","selectedChoices":["A"]}` - w := serveHTTP(srv, "POST", "/api/v1/workspaces/respond-full/respond", body) + promptToken := waitForSessionPromptToken(t, coord) + body, errMarshal := json.Marshal(respondRequestWith(func(request *apiRespondRequest) { + request.PromptToken = promptToken + request.Answer = "go with A" + request.SelectedChoices = []string{"A"} + })) + require.NoError(t, errMarshal) + w := serveHTTP(srv, "POST", "/api/v1/workspaces/respond-full/respond", string(body)) assert.Equal(t, http.StatusOK, w.Code) require.NoError(t, <-errCh) assert.Empty(t, coord.State().HumanMessage) @@ -3062,7 +3262,7 @@ func TestRespondViaCoordinatorWithoutActiveToolCall(t *testing.T) { }) })) - body := `{"answer":"go with A","selectedChoices":["A"]}` + body := `{"promptToken":"stale","answer":"go with A","selectedChoices":["A"]}` w := serveHTTP(srv, "POST", "/api/v1/workspaces/respond-wrong/respond", body) assert.Equal(t, http.StatusConflict, w.Code) } @@ -3604,6 +3804,43 @@ func TestBuildWorkspaceFullStateWithFreeformPending(t *testing.T) { assert.Equal(t, "builder", result.PendingQuestion.AgentName) } +func TestBuildWorkspaceStateUsesAggregatedParallelBatchState(t *testing.T) { + srv, rootDir := setupTestServer(t) + wsDir := setupTestWorkspace(t, srv, rootDir, "parallel-ws") + require.NoError(t, os.WriteFile(filepath.Join(wsDir, "GOAL.md"), []byte("---\n---\n# Goal"), 0o644)) + statePath := filepath.Join(wsDir, ".sgai", "state.json") + _, errCoord := state.NewCoordinatorWith(statePath, workflowWith(func(workflow *state.Workflow) { + workflow.Status = state.StatusAgentDone + workflow.Task = "stale global task" + workflow.CurrentAgent = "go-developer, react-developer" + workflow.AgentStates = map[string]state.AgentExecutionState{ + "go-developer": { + Status: state.StatusAgentDone, + Task: "", + }, + "react-developer": { + Status: state.StatusWorking, + Task: "reviewing frontend", + }, + } + })) + require.NoError(t, errCoord) + + ws := workspaceWith(func(workspace *workspaceInfo) { + workspace.DirName = "parallel-ws" + workspace.Directory = wsDir + workspace.HasWorkspace = true + }) + + list := srv.buildWorkspaceListEntry(ws, nil) + assert.Equal(t, state.StatusWorking, list.Status) + assert.Empty(t, list.Task) + + full := srv.buildWorkspaceFullState(ws, nil) + assert.Equal(t, state.StatusWorking, full.Status) + assert.Empty(t, full.Task) +} + func TestBuildWorkspaceFullStateWithLogLines(t *testing.T) { server, rootDir := setupTestServer(t) wsDir := setupTestWorkspace(t, server, rootDir, "test-ws") @@ -4196,6 +4433,15 @@ func TestResolveCurrentModel(t *testing.T) { assert.Equal(t, "claude-opus-4", result) }) + t.Run("parallelCurrentAgentsSuppressCurrentModel", func(t *testing.T) { + wfState := workflowWith(func(workflow *state.Workflow) { + workflow.CurrentAgent = "go-developer, react-developer" + workflow.CurrentModel = "claude-opus-4" + }) + result := resolveCurrentModel("/some/path", &wfState) + assert.Empty(t, result) + }) + t.Run("noAgent", func(t *testing.T) { wfState := newTestWorkflow() result := resolveCurrentModel("/some/path", &wfState) @@ -4225,25 +4471,6 @@ func TestResolveCurrentModel(t *testing.T) { }) } -func TestResolveCurrentModelNoAgentReturnsEmpty(t *testing.T) { - wf := newTestWorkflow() - result := resolveCurrentModel("/tmp", &wf) - assert.Empty(t, result) -} - -func TestResolveCurrentModelNoModel(t *testing.T) { - dir := t.TempDir() - wf := newTestWorkflow() - result := resolveCurrentModel(dir, &wf) - assert.Empty(t, result) -} - -func TestResolveCurrentModelWithExplicitModel(t *testing.T) { - wf := workflowWith(func(workflow *state.Workflow) { workflow.CurrentModel = "opus-4" }) - result := resolveCurrentModel("/tmp", &wf) - assert.Equal(t, "opus-4", result) -} - func TestSPAMiddlewareStaticAssets(t *testing.T) { srv, _ := setupTestServer(t) mux := http.NewServeMux() @@ -4495,6 +4722,7 @@ func TestHandleRespondViaCoordinator(t *testing.T) { w := httptest.NewRecorder() srv.handleRespondViaCoordinator(w, dir, coord, respondRequestWith(func(request *apiRespondRequest) { + request.PromptToken = "missing-current" request.Answer = "yes" })) assert.Equal(t, http.StatusConflict, w.Code) @@ -4521,6 +4749,7 @@ func TestHandleRespondViaCoordinator(t *testing.T) { w := httptest.NewRecorder() srv.handleRespondViaCoordinator(w, dir, coord, respondRequestWith(func(request *apiRespondRequest) { + request.PromptToken = "missing-current" request.Answer = "yes" })) assert.Equal(t, http.StatusConflict, w.Code) @@ -4943,15 +5172,20 @@ func TestGetWorkspaceStatusPreservesWorkingStatusOnDisk(t *testing.T) { func TestHandleAPIRespondViaHTTP(t *testing.T) { srv, rootDir := setupTestServer(t) wsDir := setupTestWorkspace(t, srv, rootDir, "respond-ws") - _, errCh, cancel := startWaitingSessionQuestion(t, srv, wsDir, multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { + coord, errCh, cancel := startWaitingSessionQuestion(t, srv, wsDir, multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { question.Questions = []state.QuestionItem{questionItemWith(func(item *state.QuestionItem) { item.Question = "Pick one" item.Choices = []string{"A", "B"} })} }), "Pick one") defer cancel() - body := `{"selectedChoices":["A"]}` - w := serveHTTP(srv, "POST", "/api/v1/workspaces/respond-ws/respond", body) + promptToken := waitForSessionPromptToken(t, coord) + body, errMarshal := json.Marshal(respondRequestWith(func(request *apiRespondRequest) { + request.PromptToken = promptToken + request.SelectedChoices = []string{"A"} + })) + require.NoError(t, errMarshal) + w := serveHTTP(srv, "POST", "/api/v1/workspaces/respond-ws/respond", string(body)) assert.Equal(t, http.StatusOK, w.Code) require.NoError(t, <-errCh) } diff --git a/cmd/sgai/serve_test.go b/cmd/sgai/serve_test.go index 77fa728..4dcce9b 100644 --- a/cmd/sgai/serve_test.go +++ b/cmd/sgai/serve_test.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" "testing/synctest" "time" @@ -22,6 +23,13 @@ func agentSequenceEntryWith(update func(*state.AgentSequenceEntry)) state.AgentS return updated(newTestAgentSequenceEntry(), update) } +func TestInjectCurrentAgentStyleSupportsParallelAgents(t *testing.T) { + dot := "strict digraph G {\n \"go-developer\"\n \"react-developer\"\n}" + result := injectCurrentAgentStyle(dot, "go-developer, react-developer") + assert.Contains(t, result, `"go-developer" [style=filled, fillcolor="#10b981", fontcolor=white]`) + assert.Contains(t, result, `"react-developer" [style=filled, fillcolor="#10b981", fontcolor=white]`) +} + type progressStringCase struct { name string progress []state.ProgressEntry @@ -47,6 +55,14 @@ func timestampProgressEntries(timestamps ...string) []state.ProgressEntry { return entries } +func assertWorkflowSVGContainsAgent(t *testing.T, svg, agent string) { + t.Helper() + assert.True(t, + strings.Contains(svg, agent) || strings.Contains(svg, strings.ReplaceAll(agent, "-", "-")), + "expected workflow SVG to contain %q", agent, + ) +} + func assertWorkspaceChangeInvalidatesWorkspaceCaches(t *testing.T, srv *Server, wsDir string, wfState *state.Workflow, running bool) { t.Helper() @@ -2591,6 +2607,23 @@ retrospective: true } } +func TestGetWorkflowSVGEmptyFlowIncludesParallelCurrentAgents(t *testing.T) { + workspacePath := t.TempDir() + goalData := []byte(`--- +title: Parallel SVG graph +--- +# Test Goal`) + metadata, errParse := parseYAMLFrontmatter(goalData) + require.NoError(t, errParse) + assert.Empty(t, metadata.Flow) + require.NoError(t, os.WriteFile(filepath.Join(workspacePath, "GOAL.md"), goalData, 0o644)) + + svg := getWorkflowSVG(workspacePath, "go-developer, react-developer") + assert.NotEmpty(t, svg) + assertWorkflowSVGContainsAgent(t, svg, "go-developer") + assertWorkflowSVGContainsAgent(t, svg, "react-developer") +} + func TestGetWorkflowSVGCached(t *testing.T) { rootDir := t.TempDir() server := NewServer(rootDir, newTestServerPaths(), "") diff --git a/cmd/sgai/service_session.go b/cmd/sgai/service_session.go index a0803a5..945b851 100644 --- a/cmd/sgai/service_session.go +++ b/cmd/sgai/service_session.go @@ -17,6 +17,7 @@ var ( errRootWorkspaceCannotStart = errors.New("root workspace cannot start agentic work") errSessionResetWhileRunning = errors.New("cannot reset while session is running") errNoPendingQuestion = errors.New("no pending question") + errPromptTokenRequired = errors.New("prompt token is required") errResponseCannotBeEmpty = errors.New("response cannot be empty") errQuestionNotAvailable = errors.New("question not available") errSteerMessageEmpty = errors.New("message cannot be empty") @@ -213,6 +214,10 @@ func (s *Server) resetSessionService(workspacePath string) (resetSessionResult, coord := s.workspaceCoordinator(workspacePath) if errUpdate := coord.UpdateState(func(wf *state.Workflow) { wf.Status = state.StatusComplete + wf.Task = "" + wf.CurrentAgent = "" + wf.CurrentModel = "" + wf.AgentStates = nil }); errUpdate != nil { return resetSessionResult{}, fmt.Errorf("failed to reset state: %w", errUpdate) } @@ -232,6 +237,14 @@ type respondResult struct { Message string } +func requiredPromptToken(promptToken string) (string, error) { + promptToken = strings.TrimSpace(promptToken) + if promptToken == "" { + return "", errPromptTokenRequired + } + return promptToken, nil +} + func (s *Server) respondService(workspacePath, promptToken, answer string, selectedChoices []string) (respondResult, error) { req := apiRespondRequest{ PromptToken: promptToken, @@ -263,7 +276,12 @@ func (s *Server) respondViaCoordinatorService(workspacePath string, coord *state return respondResult{}, errResponseCannotBeEmpty } - if !coord.RespondIfCurrent(req.PromptToken, responseText) { + promptToken, errPromptToken := requiredPromptToken(req.PromptToken) + if errPromptToken != nil { + return respondResult{}, errPromptToken + } + + if !coord.RespondIfCurrent(promptToken, responseText) { return respondResult{}, errQuestionNotAvailable } diff --git a/cmd/sgai/service_session_test.go b/cmd/sgai/service_session_test.go index 493da00..3c13103 100644 --- a/cmd/sgai/service_session_test.go +++ b/cmd/sgai/service_session_test.go @@ -57,7 +57,7 @@ func startCoordinatorQuestion(t *testing.T, coord *state.Coordinator, question * ctx, cancel := context.WithCancel(context.Background()) errCh := make(chan error, 1) go func() { - _, err := coord.AskAndWait(ctx, question, humanMessage) + _, err := coord.AskAndWait(ctx, question, humanMessage, "coordinator") errCh <- err }() <-ready @@ -440,8 +440,10 @@ func TestRespondViaCoordinatorServiceWorkGateApproval(t *testing.T) { question.IsWorkGate = true }), "Approve this definition?") defer cancel() + promptToken := waitForSessionPromptToken(t, coord) req := respondRequestWith(func(request *apiRespondRequest) { + request.PromptToken = promptToken request.SelectedChoices = []string{workGateApprovalText} }) @@ -456,6 +458,35 @@ func TestRespondViaCoordinatorServiceWorkGateApproval(t *testing.T) { }) } +func TestRespondViaCoordinatorServiceRequiresPromptToken(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + rootDir := t.TempDir() + server := NewServer(rootDir, newTestServerPaths(), "") + workspacePath := filepath.Join(rootDir, "test-workspace") + require.NoError(t, os.MkdirAll(filepath.Join(workspacePath, ".sgai"), 0o755)) + + coord := state.NewCoordinatorEmpty(statePath(workspacePath)) + errCh, cancel := startCoordinatorQuestion(t, coord, multiChoiceQuestionWith(func(question *state.MultiChoiceQuestion) { + question.Questions = []state.QuestionItem{questionItemWith(func(item *state.QuestionItem) { + item.Question = "Pick one" + item.Choices = []string{"A", "B"} + })} + }), "Pick one") + + _, err := server.respondViaCoordinatorService(workspacePath, coord, respondRequestWith(func(request *apiRespondRequest) { + request.Answer = "current answer" + })) + require.Error(t, err) + require.ErrorIs(t, err, errPromptTokenRequired) + assert.Contains(t, err.Error(), "prompt token is required") + assert.True(t, coord.State().NeedsHumanInput()) + + cancel() + synctest.Wait() + require.ErrorIs(t, <-errCh, context.Canceled) + }) +} + func TestRespondViaCoordinatorServiceRejectsStalePromptToken(t *testing.T) { synctest.Test(t, func(t *testing.T) { rootDir := t.TempDir() @@ -570,12 +601,13 @@ func TestRespondViaCoordinatorWorkGateApproval(t *testing.T) { question.IsWorkGate = true }), "Is this ready?") defer cancel() + promptToken := waitForSessionPromptToken(t, coord) srv.mu.Lock() srv.sessions[wsDir] = newTestServeSession(coord, false) srv.mu.Unlock() - body := `{"answer":"","selectedChoices":["` + workGateApprovalText + `"]}` + body := `{"promptToken":"` + promptToken + `","answer":"","selectedChoices":["` + workGateApprovalText + `"]}` w := serveHTTP(srv, "POST", "/api/v1/workspaces/respond-gate/respond", body) assert.Equal(t, http.StatusOK, w.Code) synctest.Wait() @@ -619,7 +651,7 @@ func TestStopSessionPublishesSingleReloadAndWorkspaceSignals(t *testing.T) { errCh := make(chan error, 1) go func() { - _, err := coord.AskAndWait(ctx, nil, "question?") + _, err := coord.AskAndWait(ctx, nil, "question?", "coordinator") errCh <- err }() diff --git a/cmd/sgai/webapp/src/__tests__/AppShortcut.test.tsx b/cmd/sgai/webapp/src/__tests__/AppShortcut.test.tsx index 5b45da9..72d8a70 100644 --- a/cmd/sgai/webapp/src/__tests__/AppShortcut.test.tsx +++ b/cmd/sgai/webapp/src/__tests__/AppShortcut.test.tsx @@ -72,7 +72,7 @@ function createMockWorkspace(overrides: Partial = {}): ApiWor events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], ...overrides, }; diff --git a/cmd/sgai/webapp/src/components/__tests__/MarkdownEditorShortcut.test.tsx b/cmd/sgai/webapp/src/components/__tests__/MarkdownEditorShortcut.test.tsx index f7e50c2..0cad465 100644 --- a/cmd/sgai/webapp/src/components/__tests__/MarkdownEditorShortcut.test.tsx +++ b/cmd/sgai/webapp/src/components/__tests__/MarkdownEditorShortcut.test.tsx @@ -67,7 +67,7 @@ const mockGetState = mock(() => ({ events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, }, diff --git a/cmd/sgai/webapp/src/pages/WorkspaceDetail.tsx b/cmd/sgai/webapp/src/pages/WorkspaceDetail.tsx index c1f6729..94a4fb9 100644 --- a/cmd/sgai/webapp/src/pages/WorkspaceDetail.tsx +++ b/cmd/sgai/webapp/src/pages/WorkspaceDetail.tsx @@ -826,7 +826,7 @@ function TabContent({ cost={detail.cost} modelStatuses={detail.modelStatuses} projectTodos={detail.projectTodos ?? []} - agentTodos={detail.agentTodos ?? []} + agentTodoSections={detail.agentTodoSections ?? []} pmContent={detail.pmContent} hasProjectMgmt={detail.hasProjectMgmt} /> diff --git a/cmd/sgai/webapp/src/pages/__tests__/AdhocOutput.test.tsx b/cmd/sgai/webapp/src/pages/__tests__/AdhocOutput.test.tsx index b28d49c..a0bbdfa 100644 --- a/cmd/sgai/webapp/src/pages/__tests__/AdhocOutput.test.tsx +++ b/cmd/sgai/webapp/src/pages/__tests__/AdhocOutput.test.tsx @@ -43,7 +43,7 @@ const mockWorkspace = { events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, }; diff --git a/cmd/sgai/webapp/src/pages/__tests__/Dashboard.test.tsx b/cmd/sgai/webapp/src/pages/__tests__/Dashboard.test.tsx index 412127f..c848584 100644 --- a/cmd/sgai/webapp/src/pages/__tests__/Dashboard.test.tsx +++ b/cmd/sgai/webapp/src/pages/__tests__/Dashboard.test.tsx @@ -83,7 +83,7 @@ const createRepositoryAction = (overrides: Record = {}) => ({ events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, ...overrides, @@ -187,7 +187,7 @@ const createMockWorkspace = (overrides: Record = {}) => ( events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, repositoryAction: createRepositoryAction(), diff --git a/cmd/sgai/webapp/src/pages/__tests__/EditGoal.test.tsx b/cmd/sgai/webapp/src/pages/__tests__/EditGoal.test.tsx index 057682b..0027a2e 100644 --- a/cmd/sgai/webapp/src/pages/__tests__/EditGoal.test.tsx +++ b/cmd/sgai/webapp/src/pages/__tests__/EditGoal.test.tsx @@ -50,7 +50,7 @@ const mockWorkspace = { events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, }; diff --git a/cmd/sgai/webapp/src/pages/__tests__/ResponseMultiChoice.test.tsx b/cmd/sgai/webapp/src/pages/__tests__/ResponseMultiChoice.test.tsx index 8b7a5f7..0742a0b 100644 --- a/cmd/sgai/webapp/src/pages/__tests__/ResponseMultiChoice.test.tsx +++ b/cmd/sgai/webapp/src/pages/__tests__/ResponseMultiChoice.test.tsx @@ -69,7 +69,7 @@ const mockWorkspace = { events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, pendingQuestion: mockQuestion, diff --git a/cmd/sgai/webapp/src/pages/__tests__/WorkspaceDetail.test.tsx b/cmd/sgai/webapp/src/pages/__tests__/WorkspaceDetail.test.tsx index d52fba2..058403f 100644 --- a/cmd/sgai/webapp/src/pages/__tests__/WorkspaceDetail.test.tsx +++ b/cmd/sgai/webapp/src/pages/__tests__/WorkspaceDetail.test.tsx @@ -70,7 +70,7 @@ const createRepositoryAction = (overrides: Record = {}) => ({ events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, ...overrides, @@ -174,7 +174,7 @@ const createMockWorkspace = (overrides = {}) => ( events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, repositoryAction: createRepositoryAction(), diff --git a/cmd/sgai/webapp/src/pages/tabs/SessionTab.tsx b/cmd/sgai/webapp/src/pages/tabs/SessionTab.tsx index 6e36aa6..56f0f3c 100644 --- a/cmd/sgai/webapp/src/pages/tabs/SessionTab.tsx +++ b/cmd/sgai/webapp/src/pages/tabs/SessionTab.tsx @@ -9,7 +9,15 @@ import { Button } from "@/components/ui/button"; import { MarkdownContent } from "@/components/MarkdownContent"; import { ChevronRight } from "lucide-react"; import { api } from "@/lib/api"; -import type { ApiAgentCost, ApiDollarBreakdown, ApiStepCost, ApiTodoEntry, ApiSessionCost, ApiTokenUsage } from "@/types"; +import type { + ApiAgentCost, + ApiAgentTodoSection, + ApiDollarBreakdown, + ApiStepCost, + ApiTodoEntry, + ApiSessionCost, + ApiTokenUsage, +} from "@/types"; interface SessionTabProps { workspaceName: string; @@ -17,7 +25,7 @@ interface SessionTabProps { cost?: ApiSessionCost; modelStatuses?: Array<{ modelId: string; status: string }>; projectTodos?: ApiTodoEntry[]; - agentTodos?: ApiTodoEntry[]; + agentTodoSections?: ApiAgentTodoSection[]; pmContent?: string; hasProjectMgmt?: boolean; } @@ -27,6 +35,7 @@ type AgentSequenceEntry = NonNullable[number]; const EMPTY_AGENT_SEQUENCE: NonNullable = []; const EMPTY_MODEL_STATUSES: NonNullable = []; const EMPTY_TODOS: ApiTodoEntry[] = []; +const EMPTY_AGENT_TODO_SECTIONS: ApiAgentTodoSection[] = []; function agentSequenceEntryKey(entry: AgentSequenceEntry, displayedIndex: number, totalEntries: number): string { const sourceOrdinal = totalEntries - displayedIndex - 1; @@ -435,7 +444,44 @@ function TodoList({ todos, emptyMessage }: { todos: ApiTodoEntry[]; emptyMessage ); } -function TasksSection({ projectTodos, agentTodos }: { projectTodos: ApiTodoEntry[]; agentTodos: ApiTodoEntry[] }) { +function agentTodoSectionKey(section: ApiAgentTodoSection, index: number): string { + const agent = typeof section.agent === "string" ? section.agent.trim() : ""; + return agent || `active-agent-${index}`; +} + +function agentTodoSectionTodos(section: ApiAgentTodoSection): ApiTodoEntry[] { + return Array.isArray(section.todos) ? section.todos : EMPTY_TODOS; +} + +function AgentTodoSectionsList({ sections }: { sections: ApiAgentTodoSection[] }) { + if (sections.length === 0) { + return

No active agents

; + } + + return ( +
+ {sections.map((section, index) => { + const agentName = typeof section.agent === "string" && section.agent.trim() ? section.agent.trim() : "Unknown agent"; + + return ( +
0 ? "border-t pt-4" : undefined}> + + +

{agentName}

+
+ {agentName} +
+
+ +
+
+ ); + })} +
+ ); +} + +function TasksSection({ projectTodos, agentTodoSections }: { projectTodos: ApiTodoEntry[]; agentTodoSections: ApiAgentTodoSection[] }) { return (
@@ -452,7 +498,7 @@ function TasksSection({ projectTodos, agentTodos }: { projectTodos: ApiTodoEntry Agent TODO - +
@@ -536,7 +582,7 @@ interface SessionStaticContentProps { cost?: ApiSessionCost; modelStatuses: NonNullable; projectTodos: ApiTodoEntry[]; - agentTodos: ApiTodoEntry[]; + agentTodoSections: ApiAgentTodoSection[]; pmContent?: string; hasProjectMgmt?: boolean; } @@ -546,7 +592,7 @@ const SessionStaticContent = memo(function SessionStaticContent({ cost, modelStatuses, projectTodos, - agentTodos, + agentTodoSections, pmContent, hasProjectMgmt, }: SessionStaticContentProps) { @@ -557,7 +603,7 @@ const SessionStaticContent = memo(function SessionStaticContent({ Tasks - + @@ -651,7 +697,7 @@ export function SessionTab({ cost, modelStatuses, projectTodos, - agentTodos, + agentTodoSections, pmContent, hasProjectMgmt, }: SessionTabProps) { @@ -663,7 +709,7 @@ export function SessionTab({ cost={cost} modelStatuses={modelStatuses ?? EMPTY_MODEL_STATUSES} projectTodos={projectTodos ?? EMPTY_TODOS} - agentTodos={agentTodos ?? EMPTY_TODOS} + agentTodoSections={agentTodoSections ?? EMPTY_AGENT_TODO_SECTIONS} pmContent={pmContent} hasProjectMgmt={hasProjectMgmt} /> diff --git a/cmd/sgai/webapp/src/pages/tabs/__tests__/EventsTab.test.tsx b/cmd/sgai/webapp/src/pages/tabs/__tests__/EventsTab.test.tsx index 1683677..e12b358 100644 --- a/cmd/sgai/webapp/src/pages/tabs/__tests__/EventsTab.test.tsx +++ b/cmd/sgai/webapp/src/pages/tabs/__tests__/EventsTab.test.tsx @@ -57,7 +57,7 @@ const createMockWorkspace = (overrides = {}) => ({ events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, ...overrides, diff --git a/cmd/sgai/webapp/src/pages/tabs/__tests__/ForksTab.test.tsx b/cmd/sgai/webapp/src/pages/tabs/__tests__/ForksTab.test.tsx index 24ab882..026b6a9 100644 --- a/cmd/sgai/webapp/src/pages/tabs/__tests__/ForksTab.test.tsx +++ b/cmd/sgai/webapp/src/pages/tabs/__tests__/ForksTab.test.tsx @@ -52,7 +52,7 @@ const createRepositoryAction = (overrides: Record = {}) => ({ events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], forks: [], log: [], actions: [], @@ -158,7 +158,7 @@ const createMockWorkspace = (overrides: Record = {}) => ( events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], forks: [], log: [], actions: [], diff --git a/cmd/sgai/webapp/src/pages/tabs/__tests__/SessionTab.test.tsx b/cmd/sgai/webapp/src/pages/tabs/__tests__/SessionTab.test.tsx index 8e46a8c..e130cff 100644 --- a/cmd/sgai/webapp/src/pages/tabs/__tests__/SessionTab.test.tsx +++ b/cmd/sgai/webapp/src/pages/tabs/__tests__/SessionTab.test.tsx @@ -40,6 +40,15 @@ const createDollarBreakdown = (overrides = {}) => ({ ...overrides, }); +type TestTodoEntry = { + id: string; + content: string; + status: string; + priority: string; +}; + +const createAgentTodoSection = (agent: string, todos: TestTodoEntry[] = []) => ({ agent, todos }); + const createMockWorkspace = (overrides = {}) => ({ name: "test-workspace", dir: "/path/to/test-workspace", @@ -80,7 +89,7 @@ const createMockWorkspace = (overrides = {}) => ({ events: [], messages: [], projectTodos: [], - agentTodos: [], + agentTodoSections: [], log: [], external: false, ...overrides, @@ -131,7 +140,7 @@ function renderSessionTab(props = {}, { strictMode = false }: { strictMode?: boo cost: workspace.cost, modelStatuses: workspace.modelStatuses, projectTodos: workspace.projectTodos, - agentTodos: workspace.agentTodos, + agentTodoSections: workspace.agentTodoSections, pmContent: workspace.pmContent as string | undefined, hasProjectMgmt: workspace.hasProjectMgmt, ...props, @@ -330,7 +339,7 @@ describe("SessionTab", () => { agentSequence={[{ agent: "coordinator", model: "opencode/glm-5", elapsedTime: "1m", isCurrent: true }]} modelStatuses={[{ modelId: "opencode/glm-5", status: "model-done" }]} projectTodos={[{ id: "todo-1", content: "Todo 1", status: "pending", priority: "medium" }]} - agentTodos={[]} + agentTodoSections={[]} /> @@ -364,11 +373,11 @@ describe("SessionTab", () => { }); }); - it("shows empty message when no agent todos", async () => { + it("shows empty message when there are no active agents", async () => { renderSessionTab(); await waitFor(() => { - expect(screen.getByText("No active agent todos")).toBeTruthy(); + expect(screen.getByText("No active agents")).toBeTruthy(); }); }); @@ -388,27 +397,43 @@ describe("SessionTab", () => { }); }); - it("displays agent todos", async () => { + it("renders grouped agent todo sections in active-agent order and shows per-agent empty states", async () => { mockWorkspaces = [createMockWorkspace({ - agentTodos: [ - { id: "1", content: "Write tests", status: "completed", priority: "high" }, + agentTodoSections: [ + createAgentTodoSection("coordinator", [ + { id: "1", content: "Write tests", status: "completed", priority: "high" }, + ]), + createAgentTodoSection("react-developer"), + createAgentTodoSection("go-developer", [ + { id: "2", content: "Update API contract", status: "pending", priority: "medium" }, + ]), ], })]; renderSessionTab(); await waitFor(() => { + expect(screen.getByRole("heading", { level: 3, name: "coordinator" })).toBeTruthy(); + expect(screen.getByRole("heading", { level: 3, name: "react-developer" })).toBeTruthy(); + expect(screen.getByRole("heading", { level: 3, name: "go-developer" })).toBeTruthy(); expect(screen.getByText("Write tests")).toBeTruthy(); + expect(screen.getByText("No active TODOs for react-developer")).toBeTruthy(); + expect(screen.getByText("Update API contract")).toBeTruthy(); }); + + const agentSectionHeadings = screen.getAllByRole("heading", { level: 3 }).map((heading) => heading.textContent); + expect(agentSectionHeadings).toEqual(["coordinator", "react-developer", "go-developer"]); }); it("does not emit duplicate-key warnings when persisted agent todos are missing ids", async () => { const consoleErrorSpy = spyOn(console, "error").mockImplementation(() => {}); mockWorkspaces = [createMockWorkspace({ - agentTodos: [ - { id: "", content: "First todo", status: "pending", priority: "high" }, - { id: "", content: "Second todo", status: "pending", priority: "medium" }, + agentTodoSections: [ + createAgentTodoSection("react-developer", [ + { id: "", content: "First todo", status: "pending", priority: "high" }, + { id: "", content: "Second todo", status: "pending", priority: "medium" }, + ]), ], })]; @@ -428,8 +453,10 @@ describe("SessionTab", () => { it("keeps the original blank-id agent todo row mounted when a same-signature blank-id todo is inserted ahead of it", async () => { const view = renderSessionTab({ - agentTodos: [ - { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + agentTodoSections: [ + createAgentTodoSection("react-developer", [ + { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + ]), ], }); @@ -448,9 +475,11 @@ describe("SessionTab", () => { cost={createMockWorkspace().cost} modelStatuses={[]} projectTodos={[]} - agentTodos={[ - { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, - { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + agentTodoSections={[ + createAgentTodoSection("react-developer", [ + { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + ]), ]} /> @@ -472,8 +501,10 @@ describe("SessionTab", () => { it("keeps the original blank-id agent todo row mounted when its mutable display fields change", async () => { const view = renderSessionTab({ - agentTodos: [ - { id: "", content: "Original todo", status: "pending", priority: "medium" }, + agentTodoSections: [ + createAgentTodoSection("react-developer", [ + { id: "", content: "Original todo", status: "pending", priority: "medium" }, + ]), ], }); @@ -492,8 +523,10 @@ describe("SessionTab", () => { cost={createMockWorkspace().cost} modelStatuses={[]} projectTodos={[]} - agentTodos={[ - { id: "", content: "Updated todo", status: "in_progress", priority: "high" }, + agentTodoSections={[ + createAgentTodoSection("react-developer", [ + { id: "", content: "Updated todo", status: "in_progress", priority: "high" }, + ]), ]} /> @@ -513,8 +546,10 @@ describe("SessionTab", () => { it("keeps the original blank-id agent todo row mounted under StrictMode when a same-signature blank-id todo is inserted ahead of it", async () => { const view = renderSessionTab({ - agentTodos: [ - { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + agentTodoSections: [ + createAgentTodoSection("react-developer", [ + { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + ]), ], }, { strictMode: true }); @@ -533,9 +568,11 @@ describe("SessionTab", () => { cost={createMockWorkspace().cost} modelStatuses={[]} projectTodos={[]} - agentTodos={[ - { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, - { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + agentTodoSections={[ + createAgentTodoSection("react-developer", [ + { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + { id: "", content: "Duplicate todo", status: "pending", priority: "medium" }, + ]), ]} /> @@ -553,8 +590,10 @@ describe("SessionTab", () => { it("keeps the original blank-id agent todo row mounted under StrictMode when its mutable display fields change", async () => { const view = renderSessionTab({ - agentTodos: [ - { id: "", content: "Original todo", status: "pending", priority: "medium" }, + agentTodoSections: [ + createAgentTodoSection("react-developer", [ + { id: "", content: "Original todo", status: "pending", priority: "medium" }, + ]), ], }, { strictMode: true }); @@ -573,8 +612,10 @@ describe("SessionTab", () => { cost={createMockWorkspace().cost} modelStatuses={[]} projectTodos={[]} - agentTodos={[ - { id: "", content: "Updated todo", status: "in_progress", priority: "high" }, + agentTodoSections={[ + createAgentTodoSection("react-developer", [ + { id: "", content: "Updated todo", status: "in_progress", priority: "high" }, + ]), ]} /> @@ -594,7 +635,7 @@ describe("SessionTab", () => { }); it("ignores abandoned StrictMode duplicate-insertion renders when later committing a mutable blank-id todo update", async () => { - const renderTree = (agentTodos: Array<{ id: string; content: string; status: string; priority: string }>, shouldSuspend: boolean) => ( + const renderTree = (agentTodoSections: Array<{ agent: string; todos: TestTodoEntry[] }>, shouldSuspend: boolean) => ( maybeWrapStrictMode( @@ -605,7 +646,7 @@ describe("SessionTab", () => { cost={createMockWorkspace().cost} modelStatuses={[]} projectTodos={[]} - agentTodos={agentTodos} + agentTodoSections={agentTodoSections} /> @@ -616,7 +657,9 @@ describe("SessionTab", () => { ); const view = render(renderTree([ - { id: "", content: "Original todo", status: "pending", priority: "medium" }, + createAgentTodoSection("react-developer", [ + { id: "", content: "Original todo", status: "pending", priority: "medium" }, + ]), ], false)); const originalRow = await waitFor(() => { @@ -627,8 +670,10 @@ describe("SessionTab", () => { startTransition(() => { view.rerender(renderTree([ - { id: "", content: "Original todo", status: "pending", priority: "medium" }, - { id: "", content: "Original todo", status: "pending", priority: "medium" }, + createAgentTodoSection("react-developer", [ + { id: "", content: "Original todo", status: "pending", priority: "medium" }, + { id: "", content: "Original todo", status: "pending", priority: "medium" }, + ]), ], true)); }); @@ -637,7 +682,9 @@ describe("SessionTab", () => { }); view.rerender(renderTree([ - { id: "", content: "Updated todo", status: "in_progress", priority: "high" }, + createAgentTodoSection("react-developer", [ + { id: "", content: "Updated todo", status: "in_progress", priority: "high" }, + ]), ], false)); const updatedRow = await waitFor(() => { @@ -846,7 +893,7 @@ describe("SessionTab", () => { cost={createMockWorkspace().cost} modelStatuses={[]} projectTodos={[]} - agentTodos={[]} + agentTodoSections={[]} /> @@ -884,7 +931,7 @@ describe("SessionTab", () => { cost={createMockWorkspace().cost} modelStatuses={[]} projectTodos={[]} - agentTodos={[]} + agentTodoSections={[]} /> diff --git a/cmd/sgai/webapp/src/types/index.ts b/cmd/sgai/webapp/src/types/index.ts index 14ef06d..e9261e4 100644 --- a/cmd/sgai/webapp/src/types/index.ts +++ b/cmd/sgai/webapp/src/types/index.ts @@ -74,7 +74,7 @@ export interface ApiWorkspaceEntry { events: ApiEventEntry[]; messages: ApiMessageEntry[]; projectTodos: ApiTodoEntry[]; - agentTodos: ApiTodoEntry[]; + agentTodoSections: ApiAgentTodoSection[]; forks?: ApiForkEntry[]; log: ApiLogEntry[]; pendingQuestion?: ApiPendingQuestionResponse; @@ -281,6 +281,11 @@ export interface ApiTodoEntry { priority: string; } +export interface ApiAgentTodoSection { + agent: string; + todos: ApiTodoEntry[]; +} + export interface ApiLogEntry { prefix: string; text: string; diff --git a/cmd/sgai/workflow_batch_state.go b/cmd/sgai/workflow_batch_state.go new file mode 100644 index 0000000..35f5dcd --- /dev/null +++ b/cmd/sgai/workflow_batch_state.go @@ -0,0 +1,91 @@ +package main + +import ( + "slices" + + "github.com/ucirello/sgai/pkg/state" +) + +func prepareCurrentBatchState(wf *state.Workflow, currentAgents []string) { + if len(currentAgents) < 2 { + wf.AgentStates = nil + wf.Status = state.StatusWorking + wf.Task = "" + return + } + + wf.AgentStates = make(map[string]state.AgentExecutionState, len(currentAgents)) + for _, currentAgent := range currentAgents { + wf.AgentStates[currentAgent] = state.AgentExecutionState{Status: state.StatusWorking, Task: ""} + } + wf.Status = parallelBatchStatus(wf) + wf.Task = "" +} + +func updateAgentWorkflowState(wf *state.Workflow, agent, status string, updateStatus bool, task string, updateTask bool) { + if !parallelBatchIncludesAgent(wf, agent) { + if updateStatus { + wf.Status = status + } + if updateTask { + wf.Task = task + } + if workflowStatusClearsTask(wf.Status) { + wf.Task = "" + } + return + } + + if wf.AgentStates == nil { + wf.AgentStates = make(map[string]state.AgentExecutionState, len(splitCurrentAgents(wf.CurrentAgent))) + } + agentState := wf.AgentStates[agent] + if updateStatus { + agentState.Status = status + } + if updateTask { + agentState.Task = task + } + if workflowStatusClearsTask(agentState.Status) { + agentState.Task = "" + } + wf.AgentStates[agent] = agentState + wf.Status = parallelBatchStatus(wf) + wf.Task = "" +} + +func visibleWorkflowStatus(wf *state.Workflow) string { + if !hasParallelCurrentAgents(wf.CurrentAgent) { + return wf.Status + } + return parallelBatchStatus(wf) +} + +func visibleWorkflowTask(wf *state.Workflow) string { + if hasParallelCurrentAgents(wf.CurrentAgent) { + return "" + } + return wf.Task +} + +func parallelBatchStatus(wf *state.Workflow) string { + currentAgents := splitCurrentAgents(wf.CurrentAgent) + if len(currentAgents) < 2 { + return wf.Status + } + for _, currentAgent := range currentAgents { + if wf.AgentStates[currentAgent].Status != state.StatusAgentDone { + return state.StatusWorking + } + } + return state.StatusAgentDone +} + +func parallelBatchIncludesAgent(wf *state.Workflow, agent string) bool { + currentAgents := splitCurrentAgents(wf.CurrentAgent) + return len(currentAgents) > 1 && slices.Contains(currentAgents, agent) +} + +func workflowStatusClearsTask(status string) bool { + return status == state.StatusAgentDone || status == state.StatusComplete +} diff --git a/cmd/sgai/workflow_runner.go b/cmd/sgai/workflow_runner.go index 36e2701..54ef4da 100644 --- a/cmd/sgai/workflow_runner.go +++ b/cmd/sgai/workflow_runner.go @@ -10,6 +10,7 @@ import ( "path/filepath" "slices" "strings" + "sync" "github.com/ucirello/sgai/pkg/state" ) @@ -28,7 +29,9 @@ type workflowRunner struct { logWriter io.Writer retroLogs retroLogWriters iterationCounter int + iterationMu sync.Mutex previousAgent string + runAgentFn func(context.Context, string) state.Workflow } type retroLogWriters struct { @@ -44,6 +47,37 @@ const ( resultInterrupt ) +func nextRunnableAgents(messages []state.Message) []string { + hasUnreadMessages := false + var recipients []string + for _, msg := range messages { + if msg.Read { + continue + } + hasUnreadMessages = true + recipient := extractAgentFromModelID(msg.ToAgent) + if recipient == "coordinator" { + return []string{"coordinator"} + } + if !slices.Contains(recipients, recipient) { + recipients = append(recipients, recipient) + } + } + if !hasUnreadMessages { + return []string{"coordinator"} + } + return recipients +} + +func hasUnreadMessages(messages []state.Message) bool { + for _, msg := range messages { + if !msg.Read { + return true + } + } + return false +} + func (r *workflowRunner) run(ctx context.Context) { for { if ctx.Err() != nil { @@ -51,8 +85,19 @@ func (r *workflowRunner) run(ctx context.Context) { return } - currentAgent := r.resolveCurrentAgent() - result := r.runAgent(ctx, currentAgent) + var errReload error + r.metadata, errReload = tryReloadGoalMetadata(r.goalPath, &r.metadata, r.flowDag) + if errReload != nil { + log.Println("failed to reload GOAL.md frontmatter:", errReload) + return + } + + currentAgents := nextRunnableAgents(r.coord.State().Messages) + if errPrepare := r.prepareAgents(currentAgents); errPrepare != nil { + log.Println("failed to prepare agent batch:", errPrepare) + return + } + result := r.runAgents(ctx, currentAgents) switch result { case resultInterrupt: @@ -64,99 +109,100 @@ func (r *workflowRunner) run(ctx context.Context) { } } -func (r *workflowRunner) resolveCurrentAgent() string { - if r.wfState.CurrentAgent == "" { - return "coordinator" - } - return r.wfState.CurrentAgent +func (r *workflowRunner) nextIteration() int { + r.iterationMu.Lock() + defer r.iterationMu.Unlock() + r.iterationCounter++ + return r.iterationCounter } -func (r *workflowRunner) runAgent(ctx context.Context, currentAgent string) runResult { - if errPrepare := r.prepareAgent(currentAgent); errPrepare != nil { - log.Println("failed to prepare agent:", errPrepare) - return resultInterrupt - } +func (r *workflowRunner) prepareAgents(currentAgents []string) error { + displayAgent := formatCurrentAgents(currentAgents) + r.wfState = r.coord.State() - var errReload error - r.metadata, errReload = tryReloadGoalMetadata(r.goalPath, &r.metadata, r.flowDag) - if errReload != nil { - log.Println("failed to reload GOAL.md frontmatter:", errReload) - return resultInterrupt + if r.previousAgent != "" && r.previousAgent != displayAgent { + log.Println("["+r.paddedsgai+"]", r.previousAgent, "->", displayAgent) + r.wfState.Todos = nil + if errOverlay := applyLayerFolderOverlay(r.dir); errOverlay != nil { + return fmt.Errorf("apply overlay on agent transition: %w", errOverlay) + } } + r.previousAgent = displayAgent - if errUnlock := unlockInteractiveForRetrospective(&r.wfState, currentAgent, r.coord, r.paddedsgai); errUnlock != nil { - log.Println("failed to unlock retrospective interaction mode:", errUnlock) - return resultInterrupt + if r.wfState.VisitCounts == nil { + r.wfState.VisitCounts = map[string]int{} } - r.wfState = r.executeAgent(ctx, currentAgent) - - if ctx.Err() != nil { - return resultInterrupt + r.wfState.CurrentAgent = displayAgent + if len(currentAgents) != 1 || extractAgentFromModelID(r.wfState.CurrentModel) != currentAgents[0] { + r.wfState.CurrentModel = "" + } + if len(currentAgents) == 1 && currentAgents[0] != "coordinator" { + setVisibleAgentTodos(&r.wfState, currentAgents[0]) + } else { + r.wfState.Todos = nil } + prepareCurrentBatchState(&r.wfState, currentAgents) + for _, currentAgent := range currentAgents { + r.wfState.VisitCounts[currentAgent]++ + addAgentHandoffProgress(&r.wfState, currentAgent) + } + markCurrentAgentsInSequence(&r.wfState, currentAgents) - if r.wfState.Status == state.StatusComplete { - redirected, errRedirect := redirectToPendingMessageAgent(&r.wfState, r.coord, r.paddedsgai) - if errRedirect != nil { - log.Println("failed to redirect to pending message agent:", errRedirect) - return resultInterrupt - } - if redirected { - return resultContinue - } - log.Println("["+r.paddedsgai+"]", "complete:", r.wfState.Task) - return resultComplete + if errReplace := r.coord.ReplaceState(&r.wfState); errReplace != nil { + return fmt.Errorf("save state: %w", errReplace) } - nextAgent := r.resolveNextAgent(currentAgent) - r.wfState.CurrentAgent = nextAgent - return resultContinue + return nil } -func (r *workflowRunner) resolveNextAgent(currentAgent string) string { - pendingAgent := findFirstPendingMessageAgent(r.wfState.Messages) - if pendingAgent != "" { - log.Println("["+r.paddedsgai+"]", "pending messages for", pendingAgent, "- redirecting") - return pendingAgent +func (r *workflowRunner) runAgents(ctx context.Context, currentAgents []string) runResult { + if len(currentAgents) == 1 { + r.wfState = r.executeCurrentAgent(ctx, currentAgents[0]) + return r.finishCurrentBatch(ctx) } - if r.flowDag.isTerminal(currentAgent) { - log.Println("["+r.paddedsgai+"]", "reached terminal node", currentAgent) - return "coordinator" + var wg sync.WaitGroup + for _, currentAgent := range currentAgents { + wg.Go(func() { + r.executeCurrentAgent(ctx, currentAgent) + }) } + wg.Wait() + r.wfState = r.coord.State() + return r.finishCurrentBatch(ctx) +} - if currentAgent == "coordinator" && len(r.flowDag.EntryNodes) > 0 { - return r.flowDag.EntryNodes[0] +func (r *workflowRunner) executeCurrentAgent(ctx context.Context, currentAgent string) state.Workflow { + if r.runAgentFn != nil { + return r.runAgentFn(ctx, currentAgent) } - - return determineNextAgent(r.flowDag, currentAgent) + return r.executeAgent(ctx, currentAgent) } -func (r *workflowRunner) prepareAgent(currentAgent string) error { - if r.previousAgent != "" && r.previousAgent != currentAgent { - log.Println("["+r.paddedsgai+"]", r.previousAgent, "->", currentAgent) - r.wfState.Todos = []state.TodoItem{} - if errOverlay := applyLayerFolderOverlay(r.dir); errOverlay != nil { - return fmt.Errorf("apply overlay on agent transition: %w", errOverlay) - } +func (r *workflowRunner) finishCurrentBatch(ctx context.Context) runResult { + if ctx.Err() != nil { + return resultInterrupt } - r.previousAgent = currentAgent - - r.wfState.CurrentAgent = currentAgent - r.wfState.VisitCounts[currentAgent]++ - addAgentHandoffProgress(&r.wfState, currentAgent) - markCurrentAgentInSequence(&r.wfState, currentAgent) - - snapshot := r.wfState - if errUpdate := r.coord.UpdateState(func(wf *state.Workflow) { - *wf = snapshot - }); errUpdate != nil { - return fmt.Errorf("save state: %w", errUpdate) + r.wfState = r.coord.State() + if r.wfState.Status != state.StatusComplete { + return resultContinue } - - return nil + if hasUnreadMessages(r.wfState.Messages) { + if errUpdate := r.coord.UpdateState(func(wf *state.Workflow) { + wf.Status = state.StatusWorking + }); errUpdate != nil { + log.Println("failed to continue workflow after complete with unread messages:", errUpdate) + return resultInterrupt + } + r.wfState = r.coord.State() + return resultContinue + } + log.Println("["+r.paddedsgai+"]", "complete:", r.wfState.Task) + return resultComplete } func (r *workflowRunner) executeAgent(ctx context.Context, currentAgent string) state.Workflow { + metadata := r.metadata cfg := multiModelConfig{ dir: r.dir, goalPath: r.goalPath, @@ -171,8 +217,13 @@ func (r *workflowRunner) executeAgent(ctx context.Context, currentAgent string) logWriter: r.logWriter, stdoutLog: r.retroLogs.stdout, stderrLog: r.retroLogs.stderr, + nextIteration: r.nextIteration, + } + currentState := r.coord.State() + if errUnlock := unlockInteractiveForRetrospective(¤tState, currentAgent, r.coord, r.paddedsgai); errUnlock != nil { + return failWorkflowState(&cfg, ¤tState, "failed to unlock retrospective interaction mode: %v", errUnlock) } - return runMultiModelAgent(ctx, &cfg, &r.wfState, &r.metadata, &r.iterationCounter) + return runMultiModelAgent(ctx, &cfg, ¤tState, &metadata) } func (r *workflowRunner) runContinuous(ctx context.Context, continuousPrompt string) { @@ -363,7 +414,9 @@ func buildWorkflowRunner(dir, mcpURL string, logWriter io.Writer, sessionCoord * logWriter: logWriter, retroLogs: retroLogs, iterationCounter: 0, + iterationMu: sync.Mutex{}, previousAgent: "", + runAgentFn: nil, } return runner, cleanup, nil } diff --git a/cmd/sgai/workflow_runner_test.go b/cmd/sgai/workflow_runner_test.go index d2e965a..30f8013 100644 --- a/cmd/sgai/workflow_runner_test.go +++ b/cmd/sgai/workflow_runner_test.go @@ -8,6 +8,7 @@ import ( "os/exec" "path/filepath" "testing" + "testing/synctest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -84,69 +85,111 @@ func TestFreshWorkflowState(t *testing.T) { assert.Equal(t, want, freshWorkflowState(allAgents, preservedMode)) } -func TestResolveCurrentAgent(t *testing.T) { - t.Run("emptyDefaultsToCoordinator", func(t *testing.T) { - r := testWorkflowRunner() - r.wfState.CurrentAgent = "" - assert.Equal(t, "coordinator", r.resolveCurrentAgent()) - }) - - t.Run("returnsCurrentAgent", func(t *testing.T) { - r := testWorkflowRunner() - r.wfState.CurrentAgent = "builder" - assert.Equal(t, "builder", r.resolveCurrentAgent()) - }) -} - -func buildTestDag(edges map[string][]string, entryNodes []string) *dag { - d := &dag{ - Nodes: make(map[string]*dagNode), - EntryNodes: entryNodes, +func TestNextRunnableAgents(t *testing.T) { + tests := []struct { + name string + messages []state.Message + want []string + }{ + { + name: "noUnreadMessagesReturnsCoordinator", + messages: nil, + want: []string{"coordinator"}, + }, + { + name: "coordinatorUnreadMessageWins", + messages: []state.Message{ + updated(newTestMessage(), func(message *state.Message) { + message.ToAgent = "go-developer" + }), + updated(newTestMessage(), func(message *state.Message) { + message.ToAgent = "coordinator" + }), + }, + want: []string{"coordinator"}, + }, + { + name: "singleUnreadRecipientReturnsThatAgent", + messages: []state.Message{ + updated(newTestMessage(), func(message *state.Message) { + message.ToAgent = "go-developer" + }), + }, + want: []string{"go-developer"}, + }, + { + name: "multipleUnreadRecipientsReturnUniqueAgentsInMessageOrder", + messages: []state.Message{ + updated(newTestMessage(), func(message *state.Message) { + message.ToAgent = "go-developer" + }), + updated(newTestMessage(), func(message *state.Message) { + message.ToAgent = "react-developer" + }), + updated(newTestMessage(), func(message *state.Message) { + message.ToAgent = "go-developer" + }), + }, + want: []string{"go-developer", "react-developer"}, + }, + { + name: "modelRecipientsCollapseToUniqueTopLevelAgents", + messages: []state.Message{ + updated(newTestMessage(), func(message *state.Message) { + message.ToAgent = "project-critic-council:model-a" + }), + updated(newTestMessage(), func(message *state.Message) { + message.ToAgent = "project-critic-council:model-b" + }), + updated(newTestMessage(), func(message *state.Message) { + message.ToAgent = "retrospective" + }), + }, + want: []string{"project-critic-council", "retrospective"}, + }, } - for from, toList := range edges { - node := d.ensureNode(from) - for _, to := range toList { - toNode := d.ensureNode(to) - node.Successors = append(node.Successors, to) - toNode.Predecessors = append(toNode.Predecessors, from) - } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, nextRunnableAgents(tt.messages)) + }) } - return d } -func TestResolveNextAgent(t *testing.T) { - t.Run("redirectsToPendingMessages", func(t *testing.T) { - r := testWorkflowRunner() - r.paddedsgai = "test" - r.flowDag = buildTestDag(map[string][]string{"coordinator": {"reviewer"}}, []string{"coordinator"}) - message := testStateMessage() - message.ID = 1 - message.FromAgent = "coordinator" - message.ToAgent = "reviewer" - message.Body = "review please" - r.wfState.Messages = []state.Message{message} - got := r.resolveNextAgent("coordinator") - assert.Equal(t, "reviewer", got) - }) +func TestRunAgentsRunsParallelRecipients(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + statePath := filepath.Join(t.TempDir(), "state.json") + coord, errCoord := state.NewCoordinatorWith(statePath, testWorkflowState()) + require.NoError(t, errCoord) - t.Run("terminalNodeReturnsCoordinator", func(t *testing.T) { r := testWorkflowRunner() - r.paddedsgai = "test" - r.flowDag = buildTestDag(map[string][]string{"coordinator": {"reviewer"}}, []string{"coordinator"}) - got := r.resolveNextAgent("reviewer") - assert.Equal(t, "coordinator", got) - }) + r.coord = coord - t.Run("coordinatorGoesToFirstEntry", func(t *testing.T) { - r := testWorkflowRunner() - r.paddedsgai = "test" - r.flowDag = buildTestDag(map[string][]string{"coordinator": {"builder"}}, []string{"builder"}) - got := r.resolveNextAgent("coordinator") - assert.Equal(t, "builder", got) + started := make(chan string, 2) + release := make(chan struct{}) + r.runAgentFn = func(_ context.Context, currentAgent string) state.Workflow { + started <- currentAgent + <-release + workflow := testWorkflowState() + workflow.Status = state.StatusAgentDone + return workflow + } + + resultCh := make(chan runResult, 1) + go func() { + resultCh <- r.runAgents(context.Background(), []string{"go-developer", "react-developer"}) + }() + + synctest.Wait() + assert.ElementsMatch(t, []string{"go-developer", "react-developer"}, []string{<-started, <-started}) + + close(release) + synctest.Wait() + assert.Equal(t, resultContinue, <-resultCh) }) } -func TestPrepareAgent(t *testing.T) { +func TestPrepareAgents(t *testing.T) { dir := t.TempDir() sgaiDir := filepath.Join(dir, ".sgai") require.NoError(t, os.MkdirAll(sgaiDir, 0o755)) @@ -163,19 +206,49 @@ func TestPrepareAgent(t *testing.T) { r.coord = coord r.wfState.Status = state.StatusWorking - require.NoError(t, r.prepareAgent("coordinator")) + require.NoError(t, r.prepareAgents([]string{"coordinator"})) assert.Equal(t, "coordinator", r.previousAgent) assert.Equal(t, "coordinator", r.wfState.CurrentAgent) assert.Equal(t, 1, r.wfState.VisitCounts["coordinator"]) - require.NoError(t, r.prepareAgent("builder")) + require.NoError(t, r.prepareAgents([]string{"builder"})) assert.Equal(t, "builder", r.previousAgent) assert.Equal(t, "builder", r.wfState.CurrentAgent) assert.Equal(t, 1, r.wfState.VisitCounts["builder"]) assert.Empty(t, r.wfState.Todos) } -func TestPrepareAgentReturnsStateSaveError(t *testing.T) { +func TestPrepareAgentsDetachesSavedWorkflowSnapshot(t *testing.T) { + dir := t.TempDir() + sgaiDir := filepath.Join(dir, ".sgai") + require.NoError(t, os.MkdirAll(sgaiDir, 0o755)) + + statePath := filepath.Join(sgaiDir, "state.json") + initial := state.NewWorkflow() + initial.Status = state.StatusWorking + initial.Progress = []state.ProgressEntry{{Timestamp: "", Agent: "", Description: "stable progress"}} + initial.TodosByAgent["builder"] = []state.TodoItem{{ID: "todo-1", Content: "stable todo", Status: "pending", Priority: "high"}} + coord, errCoord := state.NewCoordinatorWith(statePath, initial) + require.NoError(t, errCoord) + + r := testWorkflowRunner() + r.dir = dir + r.paddedsgai = "test" + r.coord = coord + r.wfState = state.NewWorkflow() + + require.NoError(t, r.prepareAgents([]string{"builder"})) + saved := coord.State() + + r.wfState.Progress[0].Description = "mutated progress" + r.wfState.VisitCounts["builder"] = 99 + r.wfState.Todos[0].Content = "mutated visible todo" + r.wfState.TodosByAgent["builder"][0].Content = "mutated grouped todo" + + assert.Equal(t, saved, coord.State()) +} + +func TestPrepareAgentsReturnsStateSaveError(t *testing.T) { dir := t.TempDir() sgaiDir := filepath.Join(dir, ".sgai") require.NoError(t, os.MkdirAll(sgaiDir, 0o755)) @@ -191,7 +264,7 @@ func TestPrepareAgentReturnsStateSaveError(t *testing.T) { r.paddedsgai = "test" r.coord = coord - errPrepare := r.prepareAgent("coordinator") + errPrepare := r.prepareAgents([]string{"coordinator"}) require.Error(t, errPrepare) require.ErrorContains(t, errPrepare, "directory") } @@ -257,7 +330,7 @@ func TestHandleTrigger(t *testing.T) { }) } -func TestPrepareAgentReappliesOverlayWithoutSkeletonUnpack(t *testing.T) { +func TestPrepareAgentsReappliesOverlayWithoutSkeletonUnpack(t *testing.T) { dir := t.TempDir() statePath := filepath.Join(dir, ".sgai", "state.json") require.NoError(t, os.MkdirAll(filepath.Join(dir, ".sgai", "agent"), 0o755)) @@ -274,7 +347,7 @@ func TestPrepareAgentReappliesOverlayWithoutSkeletonUnpack(t *testing.T) { r.paddedsgai = "test" r.previousAgent = "coordinator" - err := r.prepareAgent("builder") + err := r.prepareAgents([]string{"builder"}) require.NoError(t, err) coordinatorContent, errRead := os.ReadFile(filepath.Join(dir, ".sgai", "agent", "coordinator.md")) @@ -286,7 +359,7 @@ func TestPrepareAgentReappliesOverlayWithoutSkeletonUnpack(t *testing.T) { assert.Equal(t, "handoff overlay", string(overlayContent)) } -func TestRunAgentInterruptsWhenOverlayRefreshFails(t *testing.T) { +func TestPrepareAgentsReturnsOverlayRefreshError(t *testing.T) { dir := t.TempDir() statePath := filepath.Join(dir, ".sgai", "state.json") require.NoError(t, os.MkdirAll(filepath.Join(dir, ".sgai"), 0o755)) @@ -303,8 +376,8 @@ func TestRunAgentInterruptsWhenOverlayRefreshFails(t *testing.T) { r.paddedsgai = "test" r.previousAgent = "coordinator" - result := r.runAgent(context.Background(), "builder") - assert.Equal(t, resultInterrupt, result) + err := r.prepareAgents([]string{"builder"}) + require.Error(t, err) } func TestResolveRetrospectiveDirResuming(t *testing.T) { diff --git a/pkg/state/coordinator.go b/pkg/state/coordinator.go index fefe9c1..0488598 100644 --- a/pkg/state/coordinator.go +++ b/pkg/state/coordinator.go @@ -5,29 +5,32 @@ import ( "errors" "fmt" "log" + "slices" "strconv" "sync" "time" ) -const agentDoneWatchdogTimeout = time.Minute +// AgentDoneWatchdogTimeout bounds how long a finished agent process may linger. +const AgentDoneWatchdogTimeout = time.Minute -var errPendingHumanInput = errors.New("human input already pending") +type pendingHumanInput struct { + question *MultiChoiceQuestion + humanMessage string + agent string + promptToken string + responseCh chan string +} -// Coordinator manages workflow state in memory with blocking ask/answer delivery -// and a soft-stop watchdog for agent-done transitions. +// Coordinator manages workflow state in memory with blocking ask/answer delivery. type Coordinator struct { - mu sync.Mutex - wf Workflow - currentResponseCh chan string - currentPromptToken string - promptSeq uint64 - savePath string - saveWorkflow func(string, *Workflow) error - - doneOnce sync.Once - doneTimer *time.Timer - agentCancel context.CancelFunc + mu sync.Mutex + wf Workflow + currentPrompt *pendingHumanInput + pendingPrompts []*pendingHumanInput + promptSeq uint64 + savePath string + saveWorkflow func(string, *Workflow) error onUpdate func() } @@ -107,38 +110,50 @@ func (c *Coordinator) UpdateState(fn func(*Workflow)) error { return nil } -// AskAndWait stores the pending question in memory and blocks until Respond -// delivers an answer or ctx ends. -func (c *Coordinator) AskAndWait(ctx context.Context, question *MultiChoiceQuestion, humanMessage string) (string, error) { - responseCh := make(chan string, 1) +// ReplaceState replaces the coordinator workflow with a detached snapshot of wf +// and persists it to disk. +func (c *Coordinator) ReplaceState(wf *Workflow) error { + if wf == nil { + return errors.New("nil workflow") + } + + snapshot := wf.detached() c.mu.Lock() - if c.currentResponseCh != nil { + notify := c.onUpdate + if errSave := c.saveWorkflow(c.savePath, &snapshot); errSave != nil { c.mu.Unlock() - return "", errPendingHumanInput + return errSave } - c.currentResponseCh = responseCh - c.promptSeq++ - c.currentPromptToken = strconv.FormatUint(c.promptSeq, 10) + c.wf = snapshot c.mu.Unlock() - if err := c.UpdateState(func(wf *Workflow) { - wf.MultiChoiceQuestion = question - wf.HumanMessage = humanMessage - }); err != nil { - c.clearPendingQuestion(responseCh) - return "", fmt.Errorf("saving question state: %w", err) + if notify != nil { + notify() + } + return nil +} + +// AskAndWait stores the pending question in memory and blocks until Respond +// delivers an answer or ctx ends. +func (c *Coordinator) AskAndWait(ctx context.Context, question *MultiChoiceQuestion, humanMessage, askingAgent string) (string, error) { + prompt, isCurrent := c.enqueuePendingPrompt(question, humanMessage, askingAgent) + if isCurrent { + if err := c.persistPendingPrompt(prompt); err != nil { + c.advancePendingPrompt(prompt) + return "", fmt.Errorf("saving question state: %w", err) + } } log.Println("askandwait: blocking for human answer") select { - case answer := <-responseCh: + case answer := <-prompt.responseCh: log.Println("askandwait: answer received from human") - c.clearPendingQuestion(responseCh) + c.advancePendingPrompt(prompt) return answer, nil case <-ctx.Done(): log.Println("askandwait: context cancelled:", ctx.Err()) - c.clearPendingQuestion(responseCh) + c.advancePendingPrompt(prompt) return "", fmt.Errorf("waiting for human response: %w", ctx.Err()) } } @@ -152,10 +167,24 @@ func (c *Coordinator) Respond(answer string) bool { func (c *Coordinator) CurrentPromptToken() string { c.mu.Lock() defer c.mu.Unlock() - if c.currentResponseCh == nil { + if c.currentPrompt == nil { return "" } - return c.currentPromptToken + return c.currentPrompt.promptToken +} + +// CurrentHumanInput returns the current pending prompt token and visible +// human-input payload from one coordinator snapshot. +func (c *Coordinator) CurrentHumanInput() (promptToken, humanMessage, askingAgent string, question *MultiChoiceQuestion) { + c.mu.Lock() + defer c.mu.Unlock() + if c.currentPrompt != nil { + return c.currentPrompt.promptToken, c.currentPrompt.humanMessage, c.currentPrompt.agent, detachedMultiChoiceQuestion(c.currentPrompt.question) + } + if c.promptSeq == 0 && c.wf.NeedsHumanInput() { + return "", c.wf.HumanMessage, c.wf.HumanInputAgent, detachedMultiChoiceQuestion(c.wf.MultiChoiceQuestion) + } + return "", "", "", nil } // RespondIfCurrent delivers the answer only when promptToken matches the @@ -165,104 +194,102 @@ func (c *Coordinator) RespondIfCurrent(promptToken, answer string) bool { return c.respondIfCurrent(promptToken, answer) } -// SetAgentCancel stores the cancel function for the current agent run. -// It is called before each agent subprocess is launched so the watchdog can -// terminate that specific run if it hangs after setting status:agent-done. -func (c *Coordinator) SetAgentCancel(cancel context.CancelFunc) { - c.mu.Lock() - c.agentCancel = cancel - c.mu.Unlock() -} - -// AgentCancel returns the cancel function for the current agent run. -// Returns nil if none has been set or after ResetAgentDoneWatchdog clears it. -func (c *Coordinator) AgentCancel() context.CancelFunc { +func (c *Coordinator) respondIfCurrent(promptToken, answer string) bool { c.mu.Lock() defer c.mu.Unlock() - return c.agentCancel + return c.respondCurrentPromptLocked(c.currentPrompt, promptToken, answer) } -// ResetAgentDoneWatchdog prepares the watchdog for a fresh agent run. -// It stops any pending timer, clears the stored cancel function, and resets -// the sync.Once so the watchdog can fire again on the next agent-done. -func (c *Coordinator) ResetAgentDoneWatchdog() { - c.mu.Lock() - if c.doneTimer != nil { - c.doneTimer.Stop() - c.doneTimer = nil +func (c *Coordinator) respondCurrentPromptLocked(prompt *pendingHumanInput, promptToken, answer string) bool { + if prompt == nil { + log.Println("askandwait: no pending question, discarding response") + return false } - c.agentCancel = nil - c.doneOnce = sync.Once{} - c.mu.Unlock() -} -// StartAgentDoneWatchdog starts a one-minute timer that calls cancel once. -// Repeated calls are silently ignored via sync.Once. -func (c *Coordinator) StartAgentDoneWatchdog(cancel context.CancelFunc) { - if cancel == nil { - return + if c.currentPrompt != prompt { + log.Println("askandwait: prompt no longer current, discarding response") + return false } - c.mu.Lock() - c.doneOnce.Do(func() { - c.doneTimer = time.AfterFunc(agentDoneWatchdogTimeout, cancel) - }) - c.mu.Unlock() -} -// IsShuttingDown reports whether the agent-done watchdog has been started. -func (c *Coordinator) IsShuttingDown() bool { - c.mu.Lock() - defer c.mu.Unlock() - return c.doneTimer != nil + if promptToken != "" && promptToken != prompt.promptToken { + log.Println("askandwait: stale prompt token, discarding response") + return false + } + + select { + case prompt.responseCh <- answer: + log.Println("askandwait: response queued for delivery") + return true + default: + log.Println("askandwait: response channel full, response discarded") + return false + } } -// Stop cancels the watchdog timer if it is running. -func (c *Coordinator) Stop() { +func (c *Coordinator) enqueuePendingPrompt(question *MultiChoiceQuestion, humanMessage, askingAgent string) (*pendingHumanInput, bool) { + prompt := &pendingHumanInput{ + question: question, + humanMessage: humanMessage, + agent: askingAgent, + promptToken: "", + responseCh: make(chan string, 1), + } + c.mu.Lock() defer c.mu.Unlock() - if c.doneTimer != nil { - c.doneTimer.Stop() + c.promptSeq++ + prompt.promptToken = strconv.FormatUint(c.promptSeq, 10) + if c.currentPrompt == nil { + c.currentPrompt = prompt + return prompt, true } + c.pendingPrompts = append(c.pendingPrompts, prompt) + return prompt, false } -func (c *Coordinator) clearPendingQuestion(responseCh chan string) { - c.mu.Lock() - if c.currentResponseCh == responseCh { - c.currentResponseCh = nil - c.currentPromptToken = "" +func (c *Coordinator) advancePendingPrompt(prompt *pendingHumanInput) { + nextPrompt, updateVisiblePrompt := c.removePendingPrompt(prompt) + if !updateVisiblePrompt { + return } - c.mu.Unlock() - - if err := c.UpdateState(func(wf *Workflow) { - wf.MultiChoiceQuestion = nil - wf.HumanMessage = "" - }); err != nil { - log.Println("failed to clear pending human input:", err) + if err := c.persistPendingPrompt(nextPrompt); err != nil { + log.Println("failed to update pending human input:", err) } } -func (c *Coordinator) respondIfCurrent(promptToken, answer string) bool { +func (c *Coordinator) removePendingPrompt(prompt *pendingHumanInput) (*pendingHumanInput, bool) { c.mu.Lock() - responseCh := c.currentResponseCh - currentPromptToken := c.currentPromptToken - c.mu.Unlock() + defer c.mu.Unlock() - if responseCh == nil { - log.Println("askandwait: no pending question, discarding response") - return false + if c.currentPrompt == prompt { + if len(c.pendingPrompts) == 0 { + c.currentPrompt = nil + return nil, true + } + nextPrompt := c.pendingPrompts[0] + c.pendingPrompts = c.pendingPrompts[1:] + c.currentPrompt = nextPrompt + return nextPrompt, true } - if promptToken != "" && promptToken != currentPromptToken { - log.Println("askandwait: stale prompt token, discarding response") - return false + idx := slices.Index(c.pendingPrompts, prompt) + if idx == -1 { + return nil, false } + c.pendingPrompts = slices.Delete(c.pendingPrompts, idx, idx+1) + return nil, false +} - select { - case responseCh <- answer: - log.Println("askandwait: response queued for delivery") - return true - default: - log.Println("askandwait: response channel full, response discarded") - return false - } +func (c *Coordinator) persistPendingPrompt(prompt *pendingHumanInput) error { + return c.UpdateState(func(wf *Workflow) { + if prompt == nil { + wf.MultiChoiceQuestion = nil + wf.HumanMessage = "" + wf.HumanInputAgent = "" + return + } + wf.MultiChoiceQuestion = prompt.question + wf.HumanMessage = prompt.humanMessage + wf.HumanInputAgent = prompt.agent + }) } diff --git a/pkg/state/state.go b/pkg/state/state.go index 828aabe..d423aeb 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -3,7 +3,9 @@ package state import ( + "crypto/rand" "encoding/json" + "errors" "fmt" "maps" "os" @@ -130,21 +132,30 @@ type SessionCost struct { ByAgent []AgentCost `json:"byAgent"` } +// AgentExecutionState tracks the visible workflow status and task for one agent. +type AgentExecutionState struct { + Status string `json:"status,omitempty"` + Task string `json:"task,omitempty"` +} + // Workflow represents the complete workflow state for a sgai session. // It tracks progress, inter-agent messaging, and workflow status. type Workflow struct { - Status string `json:"status"` - Task string `json:"task"` - Progress []ProgressEntry `json:"progress"` - HumanMessage string `json:"humanMessage"` - MultiChoiceQuestion *MultiChoiceQuestion `json:"multiChoiceQuestion,omitempty"` - Messages []Message `json:"messages"` - VisitCounts map[string]int `json:"visitCounts,omitempty"` - CurrentAgent string `json:"currentAgent,omitempty"` - Todos []TodoItem `json:"todos,omitempty"` - ProjectTodos []TodoItem `json:"projectTodos,omitempty"` - AgentSequence []AgentSequenceEntry `json:"agentSequence,omitempty"` - SessionID string `json:"sessionId,omitempty"` + Status string `json:"status"` + Task string `json:"task"` + Progress []ProgressEntry `json:"progress"` + HumanMessage string `json:"humanMessage"` + HumanInputAgent string `json:"humanInputAgent,omitempty"` + MultiChoiceQuestion *MultiChoiceQuestion `json:"multiChoiceQuestion,omitempty"` + Messages []Message `json:"messages"` + VisitCounts map[string]int `json:"visitCounts,omitempty"` + CurrentAgent string `json:"currentAgent,omitempty"` + AgentStates map[string]AgentExecutionState `json:"agentStates,omitempty"` + Todos []TodoItem `json:"todos,omitempty"` + TodosByAgent map[string][]TodoItem `json:"todosByAgent,omitempty"` + ProjectTodos []TodoItem `json:"projectTodos,omitempty"` + AgentSequence []AgentSequenceEntry `json:"agentSequence,omitempty"` + SessionID string `json:"sessionId,omitempty"` Cost SessionCost `json:"cost"` @@ -176,7 +187,9 @@ func NewWorkflow() Workflow { wf.Progress = []ProgressEntry{} wf.Messages = []Message{} wf.VisitCounts = map[string]int{} + wf.AgentStates = map[string]AgentExecutionState{} wf.Todos = []TodoItem{} + wf.TodosByAgent = map[string][]TodoItem{} wf.ProjectTodos = []TodoItem{} wf.AgentSequence = []AgentSequenceEntry{} wf.Cost.ByAgent = []AgentCost{} @@ -197,7 +210,9 @@ func (w *Workflow) detached() Workflow { detached.Progress = slices.Clone(w.Progress) detached.Messages = slices.Clone(w.Messages) detached.VisitCounts = maps.Clone(w.VisitCounts) + detached.AgentStates = maps.Clone(w.AgentStates) detached.Todos = slices.Clone(w.Todos) + detached.TodosByAgent = detachedAgentTodos(w.TodosByAgent) detached.ProjectTodos = slices.Clone(w.ProjectTodos) detached.AgentSequence = slices.Clone(w.AgentSequence) detached.Cost.ByAgent = detachedAgentCosts(w.Cost.ByAgent) @@ -231,6 +246,17 @@ func detachedAgentCosts(agentCosts []AgentCost) []AgentCost { return detached } +func detachedAgentTodos(todosByAgent map[string][]TodoItem) map[string][]TodoItem { + if todosByAgent == nil { + return nil + } + detached := make(map[string][]TodoItem, len(todosByAgent)) + for agent, todos := range todosByAgent { + detached[agent] = slices.Clone(todos) + } + return detached +} + // Message represents an inter-agent message in the workflow system. type Message struct { ID int `json:"id"` @@ -255,10 +281,17 @@ func load(path string) (Workflow, error) { if wf.VisitCounts == nil { wf.VisitCounts = make(map[string]int) } + if wf.AgentStates == nil { + wf.AgentStates = make(map[string]AgentExecutionState) + } + if wf.TodosByAgent == nil { + wf.TodosByAgent = make(map[string][]TodoItem) + } if wf.ModelStatuses == nil { wf.ModelStatuses = make(map[string]string) } wf.HumanMessage = "" + wf.HumanInputAgent = "" wf.MultiChoiceQuestion = nil return wf, nil } @@ -266,6 +299,7 @@ func load(path string) (Workflow, error) { func save(path string, wf *Workflow) error { snapshot := wf.detached() snapshot.HumanMessage = "" + snapshot.HumanInputAgent = "" snapshot.MultiChoiceQuestion = nil if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return fmt.Errorf("creating workflow state directory for %s: %w", path, err) @@ -274,8 +308,99 @@ func save(path string, wf *Workflow) error { if err != nil { return fmt.Errorf("encoding workflow state %s: %w", path, err) } - if errWrite := os.WriteFile(path, data, 0o644); errWrite != nil { + if errWrite := writeWorkflowStateAtomically(path, data); errWrite != nil { return fmt.Errorf("writing workflow state %s: %w", path, errWrite) } return nil } + +func writeWorkflowStateAtomically(path string, data []byte) error { + mode, exists, errStateMode := workflowStateMode(path) + if errStateMode != nil { + return errStateMode + } + + tempFile, errCreateTemp := createWorkflowStateTempFile(filepath.Dir(path), filepath.Base(path), exists) + if errCreateTemp != nil { + return errCreateTemp + } + + tempPath := tempFile.Name() + removeTemp := true + defer func() { + if removeTemp { + _ = os.Remove(tempPath) + } + }() + + if exists { + if errChmod := tempFile.Chmod(mode); errChmod != nil { + _ = tempFile.Close() + return fmt.Errorf("setting temporary file mode: %w", errChmod) + } + } + if _, errWrite := tempFile.Write(data); errWrite != nil { + _ = tempFile.Close() + return fmt.Errorf("writing temporary file: %w", errWrite) + } + if errSync := tempFile.Sync(); errSync != nil { + _ = tempFile.Close() + return fmt.Errorf("syncing temporary file: %w", errSync) + } + if errClose := tempFile.Close(); errClose != nil { + return fmt.Errorf("closing temporary file: %w", errClose) + } + if errRename := os.Rename(tempPath, path); errRename != nil { + return fmt.Errorf("renaming temporary file: %w", errRename) + } + + removeTemp = false + return nil +} + +func createWorkflowStateTempFile(dir, base string, preserveMode bool) (*os.File, error) { + if preserveMode { + tempFile, errCreate := os.CreateTemp(dir, base+".tmp-*") + if errCreate != nil { + return nil, fmt.Errorf("creating temporary file: %w", errCreate) + } + return tempFile, nil + } + return openWorkflowStateTempFile(dir, base+".tmp-") +} + +func openWorkflowStateTempFile(dir, prefix string) (*os.File, error) { + suffix := make([]byte, 8) + const chars = "abcdefghijklmnopqrstuvwxyz0123456789" + for range 100 { + if _, errRead := rand.Read(suffix); errRead != nil { + return nil, fmt.Errorf("creating temporary file: %w", errRead) + } + for i := range suffix { + suffix[i] = chars[int(suffix[i])%len(chars)] + } + tempPath := filepath.Join(dir, prefix+string(suffix)) + tempFile, errOpen := os.OpenFile(tempPath, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o644) + if errOpen == nil { + return tempFile, nil + } + if !errors.Is(errOpen, os.ErrExist) { + return nil, fmt.Errorf("creating temporary file: %w", errOpen) + } + } + return nil, errors.New("creating temporary file: exhausted attempts") +} + +func workflowStateMode(path string) (mode os.FileMode, exists bool, err error) { + info, errStat := os.Stat(path) + if errStat == nil { + if info.IsDir() { + return 0, false, fmt.Errorf("workflow state path is a directory: %s", path) + } + return info.Mode().Perm(), true, nil + } + if errors.Is(errStat, os.ErrNotExist) { + return 0o644, false, nil + } + return 0, false, fmt.Errorf("stat existing workflow state: %w", errStat) +} diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index d41d138..0081900 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -2,9 +2,16 @@ package state import ( "context" + "errors" + "os" + "os/exec" + "os/signal" "path/filepath" + "strings" "sync" + "syscall" "testing" + "testing/synctest" "time" "github.com/stretchr/testify/assert" @@ -87,6 +94,14 @@ func testWorkflowWithReferenceFields() Workflow { return wf } +func testLargeWorkflowReplacement() Workflow { + wf := testWorkflowWithReferenceFields() + wf.Task = "replacement task" + wf.Progress = append(wf.Progress, testProgressEntry("replacement progress")) + wf.Summary = strings.Repeat("replacement summary ", 256) + return wf +} + func mutateWorkflowReferenceFields(wf *Workflow) { wf.Progress[0].Description = "mutated progress" wf.Messages[0].Body = "mutated message" @@ -102,6 +117,41 @@ func mutateWorkflowReferenceFields(wf *Workflow) { wf.MultiChoiceQuestion.Questions[0].Choices[0] = "maybe" } +func TestCoordinatorReplaceStateFileSizeLimitHelper(t *testing.T) { + statePath := os.Getenv("SGAI_STATE_REPLACE_HELPER_PATH") + if statePath == "" { + return + } + + coord, err := NewCoordinator(statePath) + require.NoError(t, err) + + signal.Ignore(syscall.SIGXFSZ) + require.NoError(t, syscall.Setrlimit(syscall.RLIMIT_FSIZE, &syscall.Rlimit{Cur: 1024, Max: 1024})) + + replacement := testLargeWorkflowReplacement() + err = coord.ReplaceState(&replacement) + require.Error(t, err) +} + +func TestSaveWorkflowStateFileModeHelper(t *testing.T) { + statePath := os.Getenv("SGAI_STATE_MODE_HELPER_PATH") + if statePath == "" { + return + } + + oldMask := syscall.Umask(0o027) + defer syscall.Umask(oldMask) + + workflow := NewWorkflow() + errSave := save(statePath, &workflow) + require.NoError(t, errSave) + + info, errStat := os.Stat(statePath) + require.NoError(t, errStat) + assert.Equal(t, os.FileMode(0o640), info.Mode().Perm()) +} + func channelClosed(ch <-chan struct{}) bool { select { case <-ch: @@ -297,6 +347,7 @@ func TestQuestionStateIsMemoryOnly(t *testing.T) { question := testMultiChoiceQuestion("yes", "no") wf := NewWorkflow() wf.HumanMessage = "Please respond" + wf.HumanInputAgent = "builder" wf.MultiChoiceQuestion = question coord, err := NewCoordinatorWith(statePath, wf) @@ -304,66 +355,106 @@ func TestQuestionStateIsMemoryOnly(t *testing.T) { snapshot := coord.State() assert.Equal(t, "Please respond", snapshot.HumanMessage) + assert.Equal(t, "builder", snapshot.HumanInputAgent) assert.NotNil(t, snapshot.MultiChoiceQuestion) loaded, err := load(statePath) require.NoError(t, err) assert.Empty(t, loaded.HumanMessage) + assert.Empty(t, loaded.HumanInputAgent) assert.Nil(t, loaded.MultiChoiceQuestion) } func TestCurrentPromptTokenChangesAcrossQuestions(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + dir := t.TempDir() + coord := NewCoordinatorEmpty(filepath.Join(dir, "state.json")) + + firstCtx, cancelFirst := context.WithCancel(context.Background()) + defer cancelFirst() + errFirstCh := make(chan error, 1) + go func() { + _, err := coord.AskAndWait(firstCtx, nil, "same question", "builder") + errFirstCh <- err + }() + + synctest.Wait() + firstToken := coord.CurrentPromptToken() + require.NotEmpty(t, firstToken) + + cancelFirst() + synctest.Wait() + require.ErrorIs(t, <-errFirstCh, context.Canceled) + assert.Empty(t, coord.CurrentPromptToken()) + + secondCtx, cancelSecond := context.WithCancel(context.Background()) + defer cancelSecond() + secondAnswerCh := make(chan string, 1) + errSecondCh := make(chan error, 1) + go func() { + answer, err := coord.AskAndWait(secondCtx, nil, "same question", "builder") + if err != nil { + errSecondCh <- err + return + } + secondAnswerCh <- answer + }() + + synctest.Wait() + secondToken := coord.CurrentPromptToken() + require.NotEmpty(t, secondToken) + assert.NotEqual(t, firstToken, secondToken) + assert.False(t, coord.RespondIfCurrent(firstToken, "stale answer")) + require.True(t, coord.RespondIfCurrent(secondToken, "current answer")) + + synctest.Wait() + select { + case answer := <-secondAnswerCh: + require.Equal(t, "current answer", answer) + case err := <-errSecondCh: + require.NoError(t, err) + default: + t.Fatal("second prompt did not complete") + } + assert.Empty(t, coord.CurrentPromptToken()) + }) +} + +func TestRespondCurrentPromptLockedRejectsPromptAfterPromotion(t *testing.T) { dir := t.TempDir() coord := NewCoordinatorEmpty(filepath.Join(dir, "state.json")) - firstCtx, cancelFirst := context.WithCancel(context.Background()) - defer cancelFirst() - firstErrCh := make(chan error, 1) - go func() { - _, err := coord.AskAndWait(firstCtx, nil, "same question") - firstErrCh <- err - }() - - firstToken := waitForCurrentPromptToken(t, coord) - cancelFirst() - require.ErrorIs(t, <-firstErrCh, context.Canceled) - require.Eventually(t, func() bool { - return coord.CurrentPromptToken() == "" - }, time.Second, 10*time.Millisecond) - - secondCtx := t.Context() - secondAnswerCh := make(chan string, 1) - secondErrCh := make(chan error, 1) - go func() { - answer, err := coord.AskAndWait(secondCtx, nil, "same question") - if err != nil { - secondErrCh <- err - return - } - secondAnswerCh <- answer - }() - - secondToken := waitForCurrentPromptToken(t, coord) - assert.NotEqual(t, firstToken, secondToken) - assert.False(t, coord.RespondIfCurrent(firstToken, "stale answer")) - assert.True(t, coord.RespondIfCurrent(secondToken, "current answer")) - assert.Equal(t, "current answer", <-secondAnswerCh) + firstPrompt, isCurrent := coord.enqueuePendingPrompt(nil, "first message", "go-developer") + require.True(t, isCurrent) + secondPrompt, isCurrent := coord.enqueuePendingPrompt(nil, "second message", "react-developer") + require.False(t, isCurrent) + + nextPrompt, updateVisiblePrompt := coord.removePendingPrompt(firstPrompt) + require.True(t, updateVisiblePrompt) + require.Same(t, secondPrompt, nextPrompt) + + coord.mu.Lock() + staleAccepted := coord.respondCurrentPromptLocked(firstPrompt, firstPrompt.promptToken, "stale answer") + coord.mu.Unlock() + assert.False(t, staleAccepted) + select { - case err := <-secondErrCh: - require.NoError(t, err) + case answer := <-firstPrompt.responseCh: + t.Fatalf("expected promoted prompt to reject stale answer, got %q", answer) default: } - require.Eventually(t, func() bool { - return coord.CurrentPromptToken() == "" - }, time.Second, 10*time.Millisecond) -} -func waitForCurrentPromptToken(t *testing.T, coord *Coordinator) string { - t.Helper() - require.Eventually(t, func() bool { - return coord.CurrentPromptToken() != "" - }, time.Second, 10*time.Millisecond) - return coord.CurrentPromptToken() + coord.mu.Lock() + currentAccepted := coord.respondCurrentPromptLocked(secondPrompt, secondPrompt.promptToken, "current answer") + coord.mu.Unlock() + require.True(t, currentAccepted) + + select { + case answer := <-secondPrompt.responseCh: + require.Equal(t, "current answer", answer) + default: + t.Fatal("expected current prompt to accept response") + } } func TestNewCoordinator(t *testing.T) { @@ -600,3 +691,89 @@ func TestCoordinatorUpdateState(t *testing.T) { assert.Equal(t, "newer progress", loaded.Progress[0].Description) }) } + +func TestCoordinatorReplaceState(t *testing.T) { + t.Run("earlySaveFailureKeepsPreviousState", func(t *testing.T) { + dir := t.TempDir() + statePath := filepath.Join(dir, "state.json") + + initial := testWorkflowWithReferenceFields() + coord, err := NewCoordinatorWith(statePath, initial) + require.NoError(t, err) + + loadedBefore, errLoad := load(statePath) + require.NoError(t, errLoad) + + replacement := testWorkflowWithReferenceFields() + replacement.Task = "replacement task" + replacement.Progress = append(replacement.Progress, testProgressEntry("replacement progress")) + + errSave := errors.New("save failed") + coord.saveWorkflow = func(_ string, _ *Workflow) error { + return errSave + } + + err = coord.ReplaceState(&replacement) + require.ErrorIs(t, err, errSave) + + assert.Equal(t, initial, coord.State()) + + loaded, errLoad := load(statePath) + require.NoError(t, errLoad) + assert.Equal(t, loadedBefore, loaded) + }) + + t.Run("realSaveFailureKeepsPreviousStateOnDisk", func(t *testing.T) { + dir := t.TempDir() + statePath := filepath.Join(dir, "state.json") + + initial := testWorkflowWithReferenceFields() + require.NoError(t, save(statePath, &initial)) + + loadedBefore, errLoad := load(statePath) + require.NoError(t, errLoad) + + cmd := exec.Command(os.Args[0], "-test.run=^TestCoordinatorReplaceStateFileSizeLimitHelper$") + cmd.Env = append(os.Environ(), "SGAI_STATE_REPLACE_HELPER_PATH="+statePath) + + output, errRun := cmd.CombinedOutput() + require.NoError(t, errRun, string(output)) + + loaded, errLoad := load(statePath) + require.NoError(t, errLoad) + assert.Equal(t, loadedBefore, loaded) + }) +} + +func TestSaveWorkflowStateFileModes(t *testing.T) { + t.Run("newFileRespectsUmask", func(t *testing.T) { + dir := t.TempDir() + statePath := filepath.Join(dir, "state.json") + + cmd := exec.Command(os.Args[0], "-test.run=^TestSaveWorkflowStateFileModeHelper$") + cmd.Env = append(os.Environ(), "SGAI_STATE_MODE_HELPER_PATH="+statePath) + + output, errRun := cmd.CombinedOutput() + require.NoError(t, errRun, string(output)) + + info, errStat := os.Stat(statePath) + require.NoError(t, errStat) + assert.Equal(t, os.FileMode(0o640), info.Mode().Perm()) + }) + + t.Run("existingFilePreservesMode", func(t *testing.T) { + dir := t.TempDir() + statePath := filepath.Join(dir, "state.json") + + errWrite := os.WriteFile(statePath, []byte("{}"), 0o600) + require.NoError(t, errWrite) + + workflow := NewWorkflow() + errSave := save(statePath, &workflow) + require.NoError(t, errSave) + + info, errStat := os.Stat(statePath) + require.NoError(t, errStat) + assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) + }) +}