diff --git a/cmd/talk.go b/cmd/talk.go index 3e14ee8..0f03a1a 100644 --- a/cmd/talk.go +++ b/cmd/talk.go @@ -8,6 +8,7 @@ import ( "os/signal" "strings" "syscall" + "time" "github.com/bgdnvk/clanker/internal/clankercloud" "github.com/bgdnvk/clanker/internal/claudecode" @@ -16,6 +17,16 @@ import ( "github.com/spf13/viper" ) +// Bridge crash recovery (clanker-cli #21). The Python bridge can die +// mid-session — bad Python deps, an unhandled exception inside a tool +// call, OOM. Before #21 the REPL just kept printing "bridge process +// exited" on every subsequent prompt. Now we restart it transparently, +// capped so a permanently-broken bridge doesn't burn CPU in a loop. +const ( + maxBridgeRestarts = 3 + bridgeRestartWin = time.Minute +) + var talkCmd = &cobra.Command{ Use: "talk", Short: "Interactive conversation with an AI agent", @@ -53,16 +64,17 @@ func runHermesTalk(parentCtx context.Context, debug bool) error { return fmt.Errorf("hermes agent not found: %w\nRun 'make setup-hermes' to install", err) } - runner := hermes.NewRunner(hermesPath, debug) - runner.SetEnv(buildHermesEnv()) - ctx, cancel := context.WithCancel(parentCtx) defer cancel() + runner := hermes.NewRunner(hermesPath, debug) + runner.SetEnv(buildHermesEnv()) if err := runner.Start(ctx); err != nil { return fmt.Errorf("failed to start hermes agent: %w", err) } - defer runner.Stop() + // Use a pointer-pointer so the deferred Stop() picks up the + // most-recent runner after a restart cycle. + defer func() { runner.Stop() }() // Handle signals: Ctrl+C interrupts the current response but does not // kill the session. A second Ctrl+C exits. @@ -79,6 +91,7 @@ func runHermesTalk(parentCtx context.Context, debug bool) error { fmt.Println("Type 'exit' or 'quit' to end the session.") fmt.Println() + restartTimes := make([]time.Time, 0, maxBridgeRestarts) scanner := bufio.NewScanner(os.Stdin) for { fmt.Print("you> ") @@ -108,34 +121,38 @@ func runHermesTalk(parentCtx context.Context, debug bool) error { } } - events, err := runner.Prompt(ctx, input) - if err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - continue - } - - fmt.Print("hermes> ") - hadDelta := false - for event := range events { - switch { - case event.Error != nil: - fmt.Fprintf(os.Stderr, "\nError: %v\n", event.Error) - case event.MessageDelta != nil: - fmt.Print(event.MessageDelta.Text) - hadDelta = true - case event.ToolCall != nil: - if debug { - fmt.Fprintf(os.Stderr, "\n[tool: %s]\n", event.ToolCall.Name) - } - case event.Thought != nil: - if debug { - fmt.Fprintf(os.Stderr, "\n[thinking: %s]\n", event.Thought.Text) - } - case event.Final != nil: - if !hadDelta && event.Final.Text != "" { - fmt.Print(event.Final.Text) + // Inner retry loop: if the bridge dies mid-prompt, restart + // it once and re-issue the same prompt so the user doesn't + // have to retype. + for attempt := 0; attempt < 2; attempt++ { + bridgeExit, err := streamHermesPrompt(ctx, runner, input, debug) + if err != nil { + if hermes.IsBridgeExitError(err) { + bridgeExit = true + } else { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + break } } + if !bridgeExit { + break + } + if attempt == 1 { + fmt.Fprintln(os.Stderr, "\nHermes bridge died again after restart — giving up on this turn.") + break + } + if !canRestartBridge(&restartTimes) { + fmt.Fprintf(os.Stderr, "\nHermes bridge crashed %d times in the last minute — refusing to restart again. Type 'exit' and rerun 'clanker talk'.\n", maxBridgeRestarts) + return fmt.Errorf("hermes bridge crashed too many times") + } + fmt.Fprintln(os.Stderr, "\nHermes bridge died — restarting...") + runner.Stop() + runner = hermes.NewRunner(hermesPath, debug) + runner.SetEnv(buildHermesEnv()) + if err := runner.Start(ctx); err != nil { + fmt.Fprintf(os.Stderr, "Failed to restart hermes bridge: %v\n", err) + return fmt.Errorf("hermes bridge restart failed: %w", err) + } } fmt.Println() fmt.Println() @@ -144,6 +161,71 @@ func runHermesTalk(parentCtx context.Context, debug bool) error { return nil } +// streamHermesPrompt runs one prompt turn. Returns (bridgeExited, err). +// bridgeExited is true when the events channel closed because the +// bridge died (so the caller knows to restart). err is set for any +// non-bridge-death problem; bridge-death errors flow through the +// boolean instead so the caller can act on them uniformly. +func streamHermesPrompt(ctx context.Context, runner *hermes.Runner, input string, debug bool) (bool, error) { + events, err := runner.Prompt(ctx, input) + if err != nil { + if hermes.IsBridgeExitError(err) { + return true, nil + } + return false, err + } + + fmt.Print("hermes> ") + hadDelta := false + bridgeExit := false + for event := range events { + switch { + case event.Error != nil: + if hermes.IsBridgeExitError(event.Error) { + bridgeExit = true + continue + } + fmt.Fprintf(os.Stderr, "\nError: %v\n", event.Error) + case event.MessageDelta != nil: + fmt.Print(event.MessageDelta.Text) + hadDelta = true + case event.ToolCall != nil: + if debug { + fmt.Fprintf(os.Stderr, "\n[tool: %s]\n", event.ToolCall.Name) + } + case event.Thought != nil: + if debug { + fmt.Fprintf(os.Stderr, "\n[thinking: %s]\n", event.Thought.Text) + } + case event.Final != nil: + if !hadDelta && event.Final.Text != "" { + fmt.Print(event.Final.Text) + } + } + } + return bridgeExit, nil +} + +// canRestartBridge implements a sliding-window rate limit. Returns +// true (and appends a fresh timestamp) when a restart is permitted, +// false when the configured ceiling has been hit inside the window. +func canRestartBridge(history *[]time.Time) bool { + now := time.Now() + cutoff := now.Add(-bridgeRestartWin) + kept := (*history)[:0] + for _, t := range *history { + if t.After(cutoff) { + kept = append(kept, t) + } + } + *history = kept + if len(*history) >= maxBridgeRestarts { + return false + } + *history = append(*history, now) + return true +} + func handleClankerCloudTalk(ctx context.Context, question string, debug bool) (bool, error) { client := clankercloud.NewClient() result, err := client.AskAgent(ctx, question, "") diff --git a/internal/hermes/ringbuffer.go b/internal/hermes/ringbuffer.go new file mode 100644 index 0000000..e38c796 --- /dev/null +++ b/internal/hermes/ringbuffer.go @@ -0,0 +1,40 @@ +package hermes + +import "sync" + +// ringBuffer is a small fixed-capacity byte buffer that keeps the most +// recent N bytes written to it. Used to capture the tail of bridge +// stderr so the restart path (clanker-cli #21) can include "what the +// bridge said before it died" in its error message — typically a +// Python traceback ending in something useful like ModuleNotFoundError. +type ringBuffer struct { + mu sync.Mutex + buf []byte + cap int +} + +func newRingBuffer(capacity int) *ringBuffer { + if capacity <= 0 { + capacity = 1 + } + return &ringBuffer{cap: capacity} +} + +// Write implements io.Writer. Keeps only the trailing `cap` bytes — +// older content is discarded as new content arrives. +func (r *ringBuffer) Write(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + r.buf = append(r.buf, p...) + if len(r.buf) > r.cap { + r.buf = r.buf[len(r.buf)-r.cap:] + } + return len(p), nil +} + +// String returns the current trailing contents as a string. +func (r *ringBuffer) String() string { + r.mu.Lock() + defer r.mu.Unlock() + return string(r.buf) +} diff --git a/internal/hermes/runner.go b/internal/hermes/runner.go index 70ac258..a5e8886 100644 --- a/internal/hermes/runner.go +++ b/internal/hermes/runner.go @@ -2,16 +2,20 @@ package hermes import ( "bufio" + "bytes" "context" _ "embed" "encoding/json" + "errors" "fmt" "io" "os" "os/exec" "path/filepath" + "runtime/debug" "strings" "sync" + "sync/atomic" "time" "github.com/spf13/viper" @@ -20,8 +24,28 @@ import ( //go:embed bridge.py var embeddedBridgeScript []byte +// ErrBridgeExited is returned by Prompt / PromptSync when the bridge +// subprocess died mid-call. cmd/talk uses this to drive its automatic +// restart loop — see clanker-cli #21. +var ErrBridgeExited = errors.New("hermes bridge process exited") + // Runner manages the lifecycle of a Hermes bridge subprocess and provides // methods to send prompts and receive streaming events. +// +// Concurrency model (clanker-cli #20): +// - A single dispatcher goroutine owns r.scanner — no other goroutine +// touches it. Responses are routed by ID via the `inbox` map and +// notifications go to the currently-active prompt's `notifSink`. +// - Prompts are serialised by `promptMu` because the bridge protocol +// does NOT tag notifications with a request ID; the dispatcher +// can't demux them across overlapping prompts. Today's callers +// (cmd/talk REPL, cmd/ask PromptSync) are already serial. The +// mutex pins that invariant explicitly so a future caller that +// fires concurrent prompts blocks safely instead of corrupting +// state. +// - `running` is atomic.Bool — read by IsRunning, written by Start +// and Stop. No data race even though Stop runs on a different +// goroutine than the dispatcher's exit path. type Runner struct { cmd *exec.Cmd stdin io.WriteCloser @@ -30,10 +54,42 @@ type Runner struct { hermesPath string env []string sessionID string - nextID int - mu sync.Mutex - running bool bridgeFile string // temp file path if we wrote the embedded bridge + + mu sync.Mutex // guards: nextID, inbox, notifSink, dispatchErr + nextID int + + // inbox routes ID-matched responses (call() + Prompt completion) + // from the dispatcher to whichever goroutine is awaiting that ID. + // Each Prompt/call allocates an ID, registers a 1-slot channel, and + // deregisters on completion. + inbox map[int]chan *Response + + // notifSink receives notifications (responses with Method != "" + // and no ID). Set under r.mu by a Prompt call before sending the + // request and cleared on completion. Serialised by promptMu so at + // most one stream-consumer exists at a time. + notifSink chan<- *Response + + // promptMu serialises Prompt calls. The bridge protocol does not + // tag notifications with a request ID — overlapping prompts + // cannot be demuxed. Hold this for the entire lifetime of a + // Prompt call (request → final response). + promptMu sync.Mutex + + // dispatchErr is set by the dispatcher when it exits (scanner + // EOF, read error, or panic). Subsequent callers see it via + // shutdownReason() so they can surface a useful message. + dispatchErr error + dispatchDone chan struct{} + + running atomic.Bool + + // stderrTail captures the last few KB of bridge stderr so the + // "Clanker talk restart" path (#21) can include the actual error + // (e.g. "ModuleNotFoundError: pydantic_ai") in its message instead + // of just saying "bridge exited." + stderrTail *ringBuffer } // FindHermesPath locates the hermes-agent vendor directory. @@ -84,6 +140,8 @@ func NewRunner(hermesPath string, debug bool) *Runner { hermesPath: hermesPath, debug: debug, sessionID: fmt.Sprintf("clanker-%d", time.Now().UnixNano()), + inbox: make(map[int]chan *Response), + nextID: 1, } } @@ -143,26 +201,41 @@ func (r *Runner) Start(ctx context.Context) error { return fmt.Errorf("failed to get stdout pipe: %w", err) } - // Send stderr to our stderr so the user sees bridge diagnostics. - r.cmd.Stderr = os.Stderr - - // Use a large scanner buffer for big JSON responses. - r.scanner = bufio.NewScanner(stdout) - buf := make([]byte, 0, 64*1024) - r.scanner.Buffer(buf, 512*1024) + // Tee stderr to a small ring buffer for the restart-error message + // path (clanker-cli #21) while still forwarding to the user's + // terminal for real-time diagnostics. + r.stderrTail = newRingBuffer(4 * 1024) + stderrPipe, err := r.cmd.StderrPipe() + if err != nil { + return fmt.Errorf("failed to get stderr pipe: %w", err) + } if err := r.cmd.Start(); err != nil { return fmt.Errorf("failed to start bridge process: %w", err) } - r.running = true + + // Start stderr drainer. Tees bridge stderr to os.Stderr (live diag) + // AND a small ring buffer (#21 restart messages). + go func() { + _, _ = io.Copy(io.MultiWriter(os.Stderr, r.stderrTail), stderrPipe) + }() + + // Wire scanner + dispatcher BEFORE the handshake. The dispatcher + // has been the sole reader of r.scanner since clanker-cli #20. + if err := r.startWithStreams(stdout); err != nil { + r.Stop() + return err + } + + r.running.Store(true) if r.debug { fmt.Fprintf(os.Stderr, "[hermes] bridge started (pid %d)\n", r.cmd.Process.Pid) } - // Send initialize handshake. - r.nextID = 1 - resp, err := r.call("initialize", nil) + // Send initialize handshake — uses call() which now goes through + // the dispatcher's inbox routing. + resp, err := r.call(ctx, "initialize", nil) if err != nil { r.Stop() return fmt.Errorf("initialize handshake failed: %w", err) @@ -181,150 +254,347 @@ func (r *Runner) Start(ctx context.Context) error { return nil } -// call sends a request and reads a single response (no streaming). -func (r *Runner) call(method string, params interface{}) (*Response, error) { - r.mu.Lock() - id := r.nextID - r.nextID++ - r.mu.Unlock() - - req := NewRequest(id, method, params) - data, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("marshal request: %w", err) - } - - r.mu.Lock() - _, err = r.stdin.Write(append(data, '\n')) - r.mu.Unlock() - if err != nil { - return nil, fmt.Errorf("write to bridge stdin: %w", err) - } +// startWithStreams wires the scanner + dispatcher around a caller-supplied +// stdout stream. Split out from Start() so tests can drive a mock bridge +// over an io.Pipe without spawning a Python process. +func (r *Runner) startWithStreams(stdout io.Reader) error { + r.scanner = bufio.NewScanner(stdout) + buf := make([]byte, 0, 64*1024) + r.scanner.Buffer(buf, 512*1024) + r.dispatchDone = make(chan struct{}) + r.dispatchErr = nil + go r.dispatch() + return nil +} - // Read lines until we get a response with our ID. - for { - if !r.scanner.Scan() { - if err := r.scanner.Err(); err != nil { - return nil, fmt.Errorf("bridge stdout read error: %w", err) +// dispatch is the single goroutine that reads r.scanner. It owns the +// scanner exclusively — no other goroutine touches it. Responses are +// routed to the registered inbox channel by ID; notifications go to +// the currently-active prompt's notifSink. On EOF / read error / panic +// it closes every pending inbox channel with ErrBridgeExited so callers +// don't block forever. +func (r *Runner) dispatch() { + defer func() { + if p := recover(); p != nil { + r.mu.Lock() + if r.dispatchErr == nil { + r.dispatchErr = fmt.Errorf("hermes dispatcher panic: %v", p) + } + r.mu.Unlock() + if r.debug { + fmt.Fprintf(os.Stderr, "[hermes] dispatcher panic: %v\n%s\n", p, debug.Stack()) } - return nil, fmt.Errorf("bridge process exited unexpectedly") + } + r.shutdownInbox() + close(r.dispatchDone) + }() + + for r.scanner.Scan() { + line := r.scanner.Bytes() + if len(bytes.TrimSpace(line)) == 0 { + continue } var resp Response - if err := json.Unmarshal(r.scanner.Bytes(), &resp); err != nil { + if err := json.Unmarshal(line, &resp); err != nil { if r.debug { fmt.Fprintf(os.Stderr, "[hermes] skipping unparseable line: %s\n", r.scanner.Text()) } continue } - if resp.ID == id { - if resp.Error != nil { - return nil, resp.Error + // ID > 0 → response to a request. Route to the registered channel. + if resp.ID > 0 { + r.mu.Lock() + ch, ok := r.inbox[resp.ID] + if ok { + delete(r.inbox, resp.ID) + } + r.mu.Unlock() + if ok { + // Non-blocking send: inbox channels are 1-buffered and + // only the awaiting goroutine reads from them, so this + // never blocks. Use select-default as defence in depth + // in case a caller abandoned (ctx-cancelled). + select { + case ch <- &resp: + default: + } + } + continue + } + + // No ID → notification. Route to the currently-active prompt's + // sink, if any. Drop silently if no prompt is in flight (or if + // the sink buffer is full; better to drop a delta than block + // the dispatcher). + if resp.Method != "" { + r.mu.Lock() + sink := r.notifSink + r.mu.Unlock() + if sink != nil { + select { + case sink <- &resp: + default: + if r.debug { + fmt.Fprintf(os.Stderr, "[hermes] dropped notification (sink full): %s\n", resp.Method) + } + } } - return &resp, nil } - // Skip notifications during the init handshake. } + + // Scanner exited. Record the reason and fan it out to every + // pending caller via shutdownInbox in the deferred path above. + r.mu.Lock() + if r.dispatchErr == nil { + if err := r.scanner.Err(); err != nil { + r.dispatchErr = fmt.Errorf("%w: %v", ErrBridgeExited, err) + } else { + r.dispatchErr = ErrBridgeExited + } + } + r.mu.Unlock() } -// Prompt sends a user prompt and returns a channel of streaming Events. -// The channel is closed after the final result event or an error. -func (r *Runner) Prompt(ctx context.Context, text string) (<-chan Event, error) { +// shutdownInbox closes every registered inbox channel with the recorded +// dispatchErr so callers blocked in call() / Prompt unblock instead of +// hanging forever. notifSink is also closed. +func (r *Runner) shutdownInbox() { + r.mu.Lock() + defer r.mu.Unlock() + // Closing the inbox channel signals the awaiter that the bridge died. + // Receivers detect this via the second return value of `<-ch`. + for id, ch := range r.inbox { + close(ch) + delete(r.inbox, id) + } + if r.notifSink != nil { + // Don't close — receiver does that. Just drop the reference so + // new sends from any latecomer hit the nil-check above. + r.notifSink = nil + } +} + +// shutdownReason returns the dispatcher's recorded exit error, or a +// generic "bridge process exited" if none was set. Used by Prompt / +// call when their inbox channel closes before yielding a response. +func (r *Runner) shutdownReason() error { + r.mu.Lock() + defer r.mu.Unlock() + if r.dispatchErr != nil { + // Augment with the stderr tail so the talk restart path (#21) + // gets useful "ModuleNotFoundError: pydantic_ai" diagnostics + // instead of just "bridge process exited". + if r.stderrTail != nil { + tail := strings.TrimSpace(r.stderrTail.String()) + if tail != "" { + return fmt.Errorf("%w; stderr tail: %s", r.dispatchErr, lastLines(tail, 3)) + } + } + return r.dispatchErr + } + return ErrBridgeExited +} + +// registerInbox allocates the next request ID and registers a 1-slot +// channel for the dispatcher to deliver the matching response to. The +// returned channel is closed by the dispatcher on bridge death so the +// caller's `<-ch` unblocks. Caller MUST defer-delete by ID to avoid +// leaking entries when ctx cancels mid-call. +func (r *Runner) registerInbox() (int, chan *Response) { r.mu.Lock() + defer r.mu.Unlock() id := r.nextID r.nextID++ - r.mu.Unlock() + ch := make(chan *Response, 1) + r.inbox[id] = ch + return id, ch +} - req := NewRequest(id, "prompt", &PromptParams{ - Text: text, - SessionID: r.sessionID, - }) +func (r *Runner) clearInbox(id int) { + r.mu.Lock() + defer r.mu.Unlock() + delete(r.inbox, id) +} + +// writeRequest serialises a JSON-RPC request to the bridge stdin under +// the mu lock so concurrent writers can't interleave bytes. +func (r *Runner) writeRequest(req *Request) error { data, err := json.Marshal(req) if err != nil { - return nil, fmt.Errorf("marshal prompt request: %w", err) + return fmt.Errorf("marshal request: %w", err) + } + r.mu.Lock() + defer r.mu.Unlock() + if r.stdin == nil { + return ErrBridgeExited + } + if _, err := r.stdin.Write(append(data, '\n')); err != nil { + return fmt.Errorf("write to bridge stdin: %w", err) } + return nil +} + +// call sends a request and waits for the response with the matching ID. +func (r *Runner) call(ctx context.Context, method string, params any) (*Response, error) { + id, ch := r.registerInbox() + defer r.clearInbox(id) + + req := NewRequest(id, method, params) + if err := r.writeRequest(req); err != nil { + return nil, err + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case resp, ok := <-ch: + if !ok { + return nil, r.shutdownReason() + } + if resp.Error != nil { + return nil, resp.Error + } + return resp, nil + } +} +// Prompt sends a user prompt and returns a channel of streaming Events. +// Serialised by promptMu — only one Prompt may be in flight per Runner +// at a time because the bridge protocol doesn't tag notifications with +// a request ID. Today's callers are all sequential; this lock pins +// that invariant so a future concurrent caller blocks instead of +// corrupting state. +func (r *Runner) Prompt(ctx context.Context, text string) (<-chan Event, error) { + // Hold the prompt lock for the lifetime of the call. The goroutine + // below releases it when it exits (final response or error). + r.promptMu.Lock() + + id, inboxCh := r.registerInbox() + notif := make(chan *Response, 64) + + // Install ourselves as the current notification sink. The dispatcher + // fans notifications here until we clear it on exit. r.mu.Lock() - _, err = r.stdin.Write(append(data, '\n')) + r.notifSink = notif r.mu.Unlock() - if err != nil { + + req := NewRequest(id, "prompt", &PromptParams{ + Text: text, + SessionID: r.sessionID, + }) + if err := r.writeRequest(req); err != nil { + r.clearInbox(id) + r.mu.Lock() + r.notifSink = nil + r.mu.Unlock() + r.promptMu.Unlock() return nil, fmt.Errorf("write prompt to bridge: %w", err) } - ch := make(chan Event, 64) - + out := make(chan Event, 64) go func() { - defer close(ch) + defer func() { + if p := recover(); p != nil { + if r.debug { + fmt.Fprintf(os.Stderr, "[hermes] prompt goroutine panic: %v\n%s\n", p, debug.Stack()) + } + select { + case out <- Event{Error: fmt.Errorf("hermes prompt panic: %v", p)}: + default: + } + } + r.clearInbox(id) + r.mu.Lock() + r.notifSink = nil + r.mu.Unlock() + close(out) + r.promptMu.Unlock() + }() for { select { case <-ctx.Done(): - ch <- Event{Error: ctx.Err()} + out <- Event{Error: ctx.Err()} return - default: - } - - if !r.scanner.Scan() { - if err := r.scanner.Err(); err != nil { - ch <- Event{Error: fmt.Errorf("bridge read error: %w", err)} - } else { - ch <- Event{Error: fmt.Errorf("bridge process exited")} + case n, ok := <-notif: + if !ok { + out <- Event{Error: r.shutdownReason()} + return } - return - } - - var resp Response - if err := json.Unmarshal(r.scanner.Bytes(), &resp); err != nil { - if r.debug { - fmt.Fprintf(os.Stderr, "[hermes] skipping bad line: %s\n", r.scanner.Text()) + ev := translateNotification(n, r.debug) + if ev != nil { + out <- *ev + } + case resp, ok := <-inboxCh: + if !ok { + out <- Event{Error: r.shutdownReason()} + return } - continue - } - - // Response with matching ID means the prompt is complete. - if resp.ID == id { if resp.Error != nil { - ch <- Event{Error: resp.Error} + out <- Event{Error: resp.Error} return } + // Drain queued notifications first. The dispatcher + // emits deltas before the final response on the wire, + // but select picks between buffered channels at + // random — without this drain a fast bridge can hand + // us a "final" before we've flushed earlier deltas + // already sitting in `notif`. + drain: + for { + select { + case n, ok := <-notif: + if !ok { + break drain + } + if ev := translateNotification(n, r.debug); ev != nil { + out <- *ev + } + default: + break drain + } + } var result PromptResult if err := json.Unmarshal(resp.Result, &result); err != nil { - ch <- Event{Error: fmt.Errorf("parse prompt result: %w", err)} + out <- Event{Error: fmt.Errorf("parse prompt result: %w", err)} return } - ch <- Event{Type: "final", Final: &result} + out <- Event{Type: "final", Final: &result} return } - - // Notification (no ID, has Method). - if resp.Method != "" { - switch resp.Method { - case MethodMessageDelta: - var p MessageDeltaParams - if err := json.Unmarshal(resp.Params, &p); err == nil { - ch <- Event{Type: MethodMessageDelta, MessageDelta: &p} - } - case MethodToolCall: - var p ToolCallParams - if err := json.Unmarshal(resp.Params, &p); err == nil { - ch <- Event{Type: MethodToolCall, ToolCall: &p} - } - case MethodThought: - var p ThoughtParams - if err := json.Unmarshal(resp.Params, &p); err == nil { - ch <- Event{Type: MethodThought, Thought: &p} - } - default: - if r.debug { - fmt.Fprintf(os.Stderr, "[hermes] unknown notification: %s\n", resp.Method) - } - } - } } }() - return ch, nil + return out, nil +} + +// translateNotification turns a JSON-RPC notification into a typed Event. +// Returns nil for unknown methods (logged in debug mode). +func translateNotification(n *Response, debugMode bool) *Event { + switch n.Method { + case MethodMessageDelta: + var p MessageDeltaParams + if err := json.Unmarshal(n.Params, &p); err == nil { + return &Event{Type: MethodMessageDelta, MessageDelta: &p} + } + case MethodToolCall: + var p ToolCallParams + if err := json.Unmarshal(n.Params, &p); err == nil { + return &Event{Type: MethodToolCall, ToolCall: &p} + } + case MethodThought: + var p ThoughtParams + if err := json.Unmarshal(n.Params, &p); err == nil { + return &Event{Type: MethodThought, Thought: &p} + } + default: + if debugMode { + fmt.Fprintf(os.Stderr, "[hermes] unknown notification: %s\n", n.Method) + } + } + return nil } // PromptSync sends a prompt and blocks until the full response is received. @@ -354,21 +624,32 @@ func (r *Runner) PromptSync(ctx context.Context, text string) (string, error) { return sb.String(), nil } -// Stop gracefully shuts down the bridge process. +// Stop gracefully shuts down the bridge process. Safe to call from any +// goroutine; idempotent. func (r *Runner) Stop() error { - if !r.running { + if !r.running.CompareAndSwap(true, false) { return nil } - r.running = false // Close stdin to signal the bridge to exit. if r.stdin != nil { r.stdin.Close() } - // Clean up temp bridge file. + // Wait for the dispatcher to drain (scanner.Scan returns false when + // stdout closes). If it doesn't, the process kill below will force it. + if r.dispatchDone != nil { + select { + case <-r.dispatchDone: + case <-time.After(2 * time.Second): + } + } + + // Clean up temp bridge file even if the process never started fully — + // previously the early-error path leaked the temp file. if r.bridgeFile != "" { os.Remove(r.bridgeFile) + r.bridgeFile = "" } if r.cmd == nil || r.cmd.Process == nil { @@ -404,7 +685,26 @@ func (r *Runner) Stop() error { } } -// IsRunning returns whether the bridge process is alive. +// IsRunning returns whether the bridge process is alive. Cheap atomic +// read; safe to call from any goroutine. func (r *Runner) IsRunning() bool { - return r.running + return r.running.Load() +} + +// IsBridgeExitError reports whether err originated from the dispatcher +// noticing the bridge died (EOF, read error, or scanner failure). Used +// by cmd/talk to decide whether to restart the bridge or surface the +// error to the user. See clanker-cli #21. +func IsBridgeExitError(err error) bool { + return errors.Is(err, ErrBridgeExited) +} + +// lastLines returns up to n trailing lines of s, separated by `;` so +// the result fits cleanly on one error-line. +func lastLines(s string, n int) string { + lines := strings.Split(strings.TrimRight(s, "\n"), "\n") + if len(lines) > n { + lines = lines[len(lines)-n:] + } + return strings.Join(lines, "; ") } diff --git a/internal/hermes/runner_test.go b/internal/hermes/runner_test.go new file mode 100644 index 0000000..707b510 --- /dev/null +++ b/internal/hermes/runner_test.go @@ -0,0 +1,311 @@ +package hermes + +import ( + "context" + "encoding/json" + "errors" + "io" + "strings" + "sync" + "testing" + "time" +) + +// newTestRunner creates a Runner wired against caller-supplied pipes +// instead of an exec.Cmd. The bridge.py subprocess is replaced by a +// goroutine that reads request JSON from `in` and writes response JSON +// to `out`. Tests use this to exercise the dispatcher under -race +// without spawning Python. +func newTestRunner(t *testing.T, stdinWriter io.WriteCloser, stdoutReader io.Reader) *Runner { + t.Helper() + r := &Runner{ + debug: false, + stdin: stdinWriter, + sessionID: "test", + inbox: make(map[int]chan *Response), + nextID: 1, + } + if err := r.startWithStreams(stdoutReader); err != nil { + t.Fatalf("startWithStreams: %v", err) + } + r.running.Store(true) + return r +} + +// teeReader records every line read so the fake bridge in the test can +// know which IDs to respond to. Each test below uses a goroutine to +// scan stdin and synthesise responses + notifications. +type pipePair struct { + stdinR io.ReadCloser + stdinW io.WriteCloser + stdoutR io.Reader + stdoutW io.WriteCloser +} + +func newPipePair() pipePair { + sr, sw := io.Pipe() + or, ow := io.Pipe() + return pipePair{stdinR: sr, stdinW: sw, stdoutR: or, stdoutW: ow} +} + +// fakeBridge reads JSON-RPC requests off stdin and dispatches them via +// the supplied responder. Returns when stdin closes. +func fakeBridge(t *testing.T, p pipePair, responder func(req Request) []Response) { + t.Helper() + decoder := json.NewDecoder(p.stdinR) + for { + var req Request + if err := decoder.Decode(&req); err != nil { + return + } + for _, resp := range responder(req) { + data, err := json.Marshal(resp) + if err != nil { + t.Errorf("fakeBridge marshal: %v", err) + return + } + if _, err := p.stdoutW.Write(append(data, '\n')); err != nil { + return + } + } + } +} + +// TestDispatcher_ResponseRoutedByID — the load-bearing assertion from +// #20. Two requests in flight (call + prompt would race the scanner +// pre-fix). Here we exercise the routing layer by issuing back-to-back +// calls and verifying each gets its own response. +func TestDispatcher_ResponseRoutedByID(t *testing.T) { + p := newPipePair() + go fakeBridge(t, p, func(req Request) []Response { + // Echo a Result envelope with the request method tag. + return []Response{{ + ID: req.ID, + Result: jsonRaw(t, map[string]string{"echo": req.Method}), + }} + }) + + r := newTestRunner(t, p.stdinW, p.stdoutR) + defer func() { _ = p.stdinW.Close(); _ = p.stdoutW.Close() }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + resp1, err := r.call(ctx, "first", nil) + if err != nil { + t.Fatalf("call(first): %v", err) + } + resp2, err := r.call(ctx, "second", nil) + if err != nil { + t.Fatalf("call(second): %v", err) + } + if string(resp1.Result) == string(resp2.Result) { + t.Errorf("expected distinct responses, got identical: %s", resp1.Result) + } + if !strings.Contains(string(resp1.Result), `"first"`) { + t.Errorf("first response should contain method tag 'first', got: %s", resp1.Result) + } +} + +// TestDispatcher_PromptNotificationsThenFinal exercises the Prompt path: +// the bridge sends two MessageDelta notifications then a final response. +// The Event channel should yield deltas in order then a Final. +func TestDispatcher_PromptNotificationsThenFinal(t *testing.T) { + p := newPipePair() + go fakeBridge(t, p, func(req Request) []Response { + if req.Method != "prompt" { + return []Response{{ID: req.ID, Result: jsonRaw(t, map[string]any{})}} + } + return []Response{ + {Method: MethodMessageDelta, Params: jsonRaw(t, MessageDeltaParams{Text: "hello "})}, + {Method: MethodMessageDelta, Params: jsonRaw(t, MessageDeltaParams{Text: "world"})}, + {ID: req.ID, Result: jsonRaw(t, PromptResult{Text: "hello world"})}, + } + }) + + r := newTestRunner(t, p.stdinW, p.stdoutR) + defer func() { _ = p.stdinW.Close(); _ = p.stdoutW.Close() }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + ch, err := r.Prompt(ctx, "anything") + if err != nil { + t.Fatalf("Prompt: %v", err) + } + + var deltas []string + var final *PromptResult + for ev := range ch { + if ev.Error != nil { + t.Fatalf("event error: %v", ev.Error) + } + if ev.MessageDelta != nil { + deltas = append(deltas, ev.MessageDelta.Text) + } + if ev.Final != nil { + final = ev.Final + } + } + if got, want := strings.Join(deltas, ""), "hello world"; got != want { + t.Errorf("deltas joined = %q, want %q", got, want) + } + if final == nil { + t.Fatal("expected a Final event, got none") + } + if final.Text != "hello world" { + t.Errorf("final text = %q, want %q", final.Text, "hello world") + } +} + +// TestDispatcher_SerializesConcurrentPrompts proves the promptMu lock +// works. Two goroutines issue Prompt at once; the lock should make them +// run end-to-end one after the other rather than interleaving. Each +// prompt's deltas should arrive on its own channel with no cross-talk. +func TestDispatcher_SerializesConcurrentPrompts(t *testing.T) { + p := newPipePair() + go fakeBridge(t, p, func(req Request) []Response { + if req.Method != "prompt" { + return []Response{{ID: req.ID, Result: jsonRaw(t, map[string]any{})}} + } + paramsJSON, _ := json.Marshal(req.Params) + var pp PromptParams + _ = json.Unmarshal(paramsJSON, &pp) + return []Response{ + {Method: MethodMessageDelta, Params: jsonRaw(t, MessageDeltaParams{Text: "[" + pp.Text + "]"})}, + {ID: req.ID, Result: jsonRaw(t, PromptResult{Text: pp.Text})}, + } + }) + + r := newTestRunner(t, p.stdinW, p.stdoutR) + defer func() { _ = p.stdinW.Close(); _ = p.stdoutW.Close() }() + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + results := make(chan string, 2) + var wg sync.WaitGroup + for _, text := range []string{"alpha", "beta"} { + wg.Add(1) + go func(text string) { + defer wg.Done() + out, err := r.PromptSync(ctx, text) + if err != nil { + t.Errorf("PromptSync(%s): %v", text, err) + return + } + results <- out + "|" + text + }(text) + } + wg.Wait() + close(results) + + seen := map[string]bool{} + for r := range results { + seen[r] = true + } + // Each goroutine should receive ITS OWN response — the dispatcher + // must not have crossed the streams. + if !seen["[alpha]|alpha"] { + t.Errorf("alpha goroutine did not receive its own delta+final (results: %v)", seen) + } + if !seen["[beta]|beta"] { + t.Errorf("beta goroutine did not receive its own delta+final (results: %v)", seen) + } +} + +// TestDispatcher_BridgeDeathUnblocksCallers — when the bridge stdout +// closes mid-call, every pending call() / Prompt should unblock with +// ErrBridgeExited. Pre-fix the awaiter looped forever on scanner.Scan +// returning false but never reaching the inbox channel. +func TestDispatcher_BridgeDeathUnblocksCallers(t *testing.T) { + p := newPipePair() + // Drain stdin (acting as a no-op bridge) so the request write + // doesn't block on an io.Pipe with no reader. + go func() { _, _ = io.Copy(io.Discard, p.stdinR) }() + r := newTestRunner(t, p.stdinW, p.stdoutR) + defer func() { _ = p.stdinW.Close() }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + errCh := make(chan error, 1) + go func() { + _, err := r.call(ctx, "ping", nil) + errCh <- err + }() + + // Give the call goroutine a moment to write its request. + time.Sleep(50 * time.Millisecond) + + // Simulate bridge dying: close stdout from "bridge" side. + _ = p.stdoutW.Close() + + select { + case err := <-errCh: + if !IsBridgeExitError(err) { + t.Errorf("expected ErrBridgeExited, got: %v", err) + } + case <-time.After(time.Second): + t.Fatal("call did not unblock after bridge death within 1s — regression of #20") + } +} + +// TestDispatcher_DroppedNotificationDoesNotBlock — if the notif sink +// buffer fills (rare in practice but possible if the consumer is slow), +// the dispatcher must drop rather than block. Otherwise one stuck +// consumer wedges the whole runner including the final-response path. +func TestDispatcher_DroppedNotificationDoesNotBlock(t *testing.T) { + p := newPipePair() + go fakeBridge(t, p, func(req Request) []Response { + if req.Method != "prompt" { + return []Response{{ID: req.ID, Result: jsonRaw(t, map[string]any{})}} + } + // Send 200 deltas (more than the 64-buffer notif sink) then final. + resps := make([]Response, 0, 201) + for range 200 { + resps = append(resps, Response{Method: MethodMessageDelta, Params: jsonRaw(t, MessageDeltaParams{Text: "x"})}) + } + resps = append(resps, Response{ID: req.ID, Result: jsonRaw(t, PromptResult{Text: "done"})}) + return resps + }) + + r := newTestRunner(t, p.stdinW, p.stdoutR) + defer func() { _ = p.stdinW.Close(); _ = p.stdoutW.Close() }() + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + // The consumer (PromptSync) drains the channel so deltas flow. The + // concern here is that even under flood the final response still + // arrives — verifies the dispatcher's select-default drop semantics. + final, err := r.PromptSync(ctx, "flood") + if err != nil { + t.Fatalf("PromptSync under flood: %v", err) + } + if final == "" { + t.Error("expected text output from 200-delta flood") + } +} + +func TestIsBridgeExitError(t *testing.T) { + if !IsBridgeExitError(ErrBridgeExited) { + t.Error("ErrBridgeExited should be recognised by IsBridgeExitError") + } + wrapped := errors.New("oh no") + if IsBridgeExitError(wrapped) { + t.Error("plain error should not be recognised as bridge-exit") + } + if !IsBridgeExitError(errors.Join(ErrBridgeExited, errors.New("io error"))) { + t.Error("errors.Join with ErrBridgeExited should be recognised") + } +} + +func jsonRaw(t *testing.T, v any) json.RawMessage { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return b +}