From f79733d37c86cc40517e4f2f8268855211dec9d6 Mon Sep 17 00:00:00 2001 From: Amos Bird Date: Tue, 31 Mar 2026 02:36:32 +0800 Subject: [PATCH 1/4] fix: stream idle timeout and goroutine leak in processStepStream - Add StreamIdleTimeout option to AgentStreamCall/AgentCall to cancel streams that stop sending data within a configurable duration - Fix goroutine leak: use sync.Once + defer to ensure close(toolChan) and toolExecutionWg.Wait() run on all return paths in processStepStream --- agent.go | 117 ++++++++++++++++++++++++++++++---------- agent_stream_test.go | 124 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+), 27 deletions(-) diff --git a/agent.go b/agent.go index 34d4f8a1f..654c0291f 100644 --- a/agent.go +++ b/agent.go @@ -10,6 +10,8 @@ import ( "maps" "slices" "sync" + "sync/atomic" + "time" "charm.land/fantasy/schema" "github.com/charmbracelet/x/exp/slice" @@ -172,9 +174,10 @@ type AgentCall struct { OnRetry OnRetryCallback MaxRetries *int - StopWhen []StopCondition - PrepareStep PrepareStepFunction - RepairToolCall RepairToolCallFunction + StopWhen []StopCondition + PrepareStep PrepareStepFunction + RepairToolCall RepairToolCallFunction + StreamIdleTimeout time.Duration // Cancels the stream if no data arrives within this duration. } // Agent-level callbacks. @@ -263,9 +266,10 @@ type AgentStreamCall struct { OnRetry OnRetryCallback MaxRetries *int - StopWhen []StopCondition - PrepareStep PrepareStepFunction - RepairToolCall RepairToolCallFunction + StopWhen []StopCondition + PrepareStep PrepareStepFunction + RepairToolCall RepairToolCallFunction + StreamIdleTimeout time.Duration // Cancels the stream if no data arrives within this duration. 
// Agent-level callbacks OnAgentStart OnAgentStartFunc // Called when agent starts @@ -761,22 +765,23 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) { // Convert AgentStreamCall to AgentCall for preparation call := AgentCall{ - Prompt: opts.Prompt, - Files: opts.Files, - Messages: opts.Messages, - MaxOutputTokens: opts.MaxOutputTokens, - Temperature: opts.Temperature, - TopP: opts.TopP, - TopK: opts.TopK, - PresencePenalty: opts.PresencePenalty, - FrequencyPenalty: opts.FrequencyPenalty, - ActiveTools: opts.ActiveTools, - ProviderOptions: opts.ProviderOptions, - MaxRetries: opts.MaxRetries, - OnRetry: opts.OnRetry, - StopWhen: opts.StopWhen, - PrepareStep: opts.PrepareStep, - RepairToolCall: opts.RepairToolCall, + Prompt: opts.Prompt, + Files: opts.Files, + Messages: opts.Messages, + MaxOutputTokens: opts.MaxOutputTokens, + Temperature: opts.Temperature, + TopP: opts.TopP, + TopK: opts.TopK, + PresencePenalty: opts.PresencePenalty, + FrequencyPenalty: opts.FrequencyPenalty, + ActiveTools: opts.ActiveTools, + ProviderOptions: opts.ProviderOptions, + MaxRetries: opts.MaxRetries, + OnRetry: opts.OnRetry, + StopWhen: opts.StopWhen, + PrepareStep: opts.PrepareStep, + RepairToolCall: opts.RepairToolCall, + StreamIdleTimeout: opts.StreamIdleTimeout, } call = a.prepareCall(call) @@ -884,14 +889,28 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions) result, err := retry(ctx, func() (stepExecutionResult, error) { - // Create the stream - stream, err := stepModel.Stream(ctx, streamCall) + streamCtx := ctx + var streamCancel context.CancelFunc + if call.StreamIdleTimeout > 0 { + streamCtx, streamCancel = context.WithCancel(ctx) + } + + stream, err := stepModel.Stream(streamCtx, streamCall) if err != nil { + if streamCancel != nil 
{ + streamCancel() + } return stepExecutionResult{}, err } - // Process the stream + if call.StreamIdleTimeout > 0 { + stream = withIdleTimeout(stream, call.StreamIdleTimeout, streamCancel) + } + result, err := a.processStepStream(ctx, stream, opts, steps, stepTools, stepExecProviderTools) + if streamCancel != nil { + streamCancel() + } if err != nil { return stepExecutionResult{}, err } @@ -1248,11 +1267,17 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op parallel bool } toolChan := make(chan toolExecutionRequest, 10) + var closeToolChan sync.Once var toolExecutionWg sync.WaitGroup var toolStateMu sync.Mutex toolResults := make([]ToolResultContent, 0) var toolExecutionErr error + defer func() { + closeToolChan.Do(func() { close(toolChan) }) + toolExecutionWg.Wait() + }() + // Create a map for quick tool lookup toolMap := make(map[string]AgentTool) for _, tool := range stepTools { @@ -1534,8 +1559,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op } } - // Close the tool execution channel and wait for all executions to complete - close(toolChan) + // Ensure the tool channel is closed and all tool executions complete. + // This is also handled by the deferred cleanup, but closing eagerly + // here allows us to inspect tool execution errors before returning. + closeToolChan.Do(func() { close(toolChan) }) toolExecutionWg.Wait() // Check for tool execution errors @@ -1602,3 +1629,39 @@ func WithProviderOptions(providerOptions ProviderOptions) AgentOption { s.providerOptions = providerOptions } } + +// withIdleTimeout wraps a StreamResponse so that if no stream part is +// received within the given timeout, cancelFn is called to cancel the +// underlying HTTP request context. Each received part resets the timer. +// The timer goroutine exits when the wrapped iterator returns. 
+func withIdleTimeout(stream StreamResponse, timeout time.Duration, cancelFn context.CancelFunc) StreamResponse { + return func(yield func(StreamPart) bool) { + timer := time.NewTimer(timeout) + done := make(chan struct{}) + var timedOut atomic.Bool + + go func() { + select { + case <-timer.C: + timedOut.Store(true) + cancelFn() + case <-done: + } + timer.Stop() + }() + + defer close(done) + + stream(func(part StreamPart) bool { + timer.Reset(timeout) + return yield(part) + }) + + if timedOut.Load() { + yield(StreamPart{ + Type: StreamPartTypeError, + Error: fmt.Errorf("stream idle timeout exceeded (%s)", timeout), + }) + } + } +} diff --git a/agent_stream_test.go b/agent_stream_test.go index ea9009305..459fd2073 100644 --- a/agent_stream_test.go +++ b/agent_stream_test.go @@ -3,8 +3,10 @@ package fantasy import ( "context" "encoding/json" + "errors" "fmt" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -595,3 +597,125 @@ func TestStreamingAgentSources(t *testing.T) { resultSources := result.Response.Content.Sources() require.Equal(t, 2, len(resultSources)) } + +// TestStreamingAgentIdleTimeout verifies that a hanging stream is cancelled +// after the idle timeout fires. +func TestStreamingAgentIdleTimeout(t *testing.T) { + t.Parallel() + + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + return func(yield func(StreamPart) bool) { + if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) { + return + } + // Simulate a hang: block until context is cancelled. 
+ <-ctx.Done() + }, nil + }, + } + + agent := NewAgent(mockModel) + ctx := context.Background() + + streamCall := AgentStreamCall{ + Prompt: "Say hello", + StreamIdleTimeout: 100 * time.Millisecond, + } + + start := time.Now() + _, err := agent.Stream(ctx, streamCall) + elapsed := time.Since(start) + + require.Error(t, err) + require.Contains(t, err.Error(), "stream idle timeout exceeded") + require.Less(t, elapsed, 2*time.Second, "should not block for a long time") +} + +// TestStreamingAgentIdleTimeoutResetsOnChunks verifies that the idle timer +// resets with each chunk so a slow-but-active stream succeeds. +func TestStreamingAgentIdleTimeoutResetsOnChunks(t *testing.T) { + t.Parallel() + + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + return func(yield func(StreamPart) bool) { + if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { + return + } + // Yield deltas with pauses shorter than the idle timeout. + for _, word := range []string{"Hello", ", ", "world", "!"} { + time.Sleep(30 * time.Millisecond) + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: word}) { + return + } + } + if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) { + return + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{InputTokens: 3, OutputTokens: 4, TotalTokens: 7}, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + agent := NewAgent(mockModel) + ctx := context.Background() + + streamCall := AgentStreamCall{ + Prompt: "Say hello", + StreamIdleTimeout: 100 * time.Millisecond, + } + + result, err := agent.Stream(ctx, streamCall) + require.NoError(t, err) + require.Equal(t, "Hello, world!", result.Response.Content.Text()) +} + +// TestStreamingAgentCallbackErrorCleanup verifies that an early return from a +// callback error properly cleans up the tool coordinator goroutine (no leak). 
+func TestStreamingAgentCallbackErrorCleanup(t *testing.T) { + t.Parallel() + + callbackErr := errors.New("callback forced error") + + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + return func(yield func(StreamPart) bool) { + if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) { + return + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{InputTokens: 1, OutputTokens: 1, TotalTokens: 2}, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + agent := NewAgent(mockModel) + ctx := context.Background() + + streamCall := AgentStreamCall{ + Prompt: "Say hello", + OnTextDelta: func(_, _ string) error { + return callbackErr + }, + } + + _, err := agent.Stream(ctx, streamCall) + require.ErrorIs(t, err, callbackErr) +} From 68ffec162b663b937105b599fe711d7bc351065d Mon Sep 17 00:00:00 2001 From: Amos Bird Date: Fri, 3 Apr 2026 20:02:17 +0800 Subject: [PATCH 2/4] fix: prevent yield after consumer stop in withIdleTimeout When the consumer breaks out of the range loop (yield returns false) and an idle timeout fires concurrently, withIdleTimeout would call yield again after it already returned false, violating the range-over-function protocol and causing a runtime panic: runtime error: range function continued iteration after function for loop body returned false Track whether the consumer has stopped with a 'stopped' flag and skip the timeout error yield when it's already set. 
--- agent.go | 9 +++++++-- agent_stream_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/agent.go b/agent.go index 654c0291f..b848902ef 100644 --- a/agent.go +++ b/agent.go @@ -1652,12 +1652,17 @@ func withIdleTimeout(stream StreamResponse, timeout time.Duration, cancelFn cont defer close(done) + var stopped bool stream(func(part StreamPart) bool { timer.Reset(timeout) - return yield(part) + if !yield(part) { + stopped = true + return false + } + return true }) - if timedOut.Load() { + if timedOut.Load() && !stopped { yield(StreamPart{ Type: StreamPartTypeError, Error: fmt.Errorf("stream idle timeout exceeded (%s)", timeout), diff --git a/agent_stream_test.go b/agent_stream_test.go index 459fd2073..6d9582677 100644 --- a/agent_stream_test.go +++ b/agent_stream_test.go @@ -719,3 +719,46 @@ func TestStreamingAgentCallbackErrorCleanup(t *testing.T) { _, err := agent.Stream(ctx, streamCall) require.ErrorIs(t, err, callbackErr) } + +// TestStreamingAgentIdleTimeoutNoYieldAfterStop verifies that withIdleTimeout +// does not call yield after the consumer has stopped iteration, which would +// panic with "range function continued iteration after loop body returned false". +func TestStreamingAgentIdleTimeoutNoYieldAfterStop(t *testing.T) { + t.Parallel() + + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + return func(yield func(StreamPart) bool) { + if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) { + return + } + // Simulate a hang so the idle timeout fires while we're blocked. + <-ctx.Done() + // After context cancellation, the provider may still yield an error. 
+ yield(StreamPart{ + Type: StreamPartTypeError, + Error: ctx.Err(), + }) + }, nil + }, + } + + agent := NewAgent(mockModel) + ctx := context.Background() + + streamCall := AgentStreamCall{ + Prompt: "Say hello", + StreamIdleTimeout: 50 * time.Millisecond, + OnTextDelta: func(_, _ string) error { + return errors.New("consumer error") + }, + } + + // This must not panic with "range function continued iteration after + // loop body returned false". + _, err := agent.Stream(ctx, streamCall) + require.Error(t, err) +} From f595a192675483482bc6fe726ddf24be2b37e048 Mon Sep 17 00:00:00 2001 From: Amos Bird Date: Fri, 3 Apr 2026 23:58:05 +0800 Subject: [PATCH 3/4] fix: prevent stream idle timeout from being misidentified as user cancel When the StreamIdleTimeout timer fires, it cancels the child streamCtx. The provider then yields a StreamPartTypeError carrying context.Canceled. The retry logic's isAbortError checked the error value (not the parent ctx), so it treated this as a user-initiated cancellation and aborted without retrying. Two fixes: - isAbortError now checks ctx.Err() instead of errors.Is(err, context.Canceled). If the parent context is still alive, the cancellation came from a child (idle timeout), not the user. - withIdleTimeout replaces the error of any StreamPartTypeError part yielded after the timer has fired with a descriptive error wrapping errStreamIdleTimeout ("stream idle timeout: no data received for <timeout>"), so even if the error escapes the retry loop it won't be confused with a user cancel. 
--- agent.go | 10 +++++++++- agent_stream_test.go | 12 +++++++++--- errors.go | 4 ++++ retry.go | 19 +++++++++++++------ 4 files changed, 35 insertions(+), 10 deletions(-) diff --git a/agent.go b/agent.go index b848902ef..e705c48dc 100644 --- a/agent.go +++ b/agent.go @@ -1655,6 +1655,14 @@ func withIdleTimeout(stream StreamResponse, timeout time.Duration, cancelFn cont var stopped bool stream(func(part StreamPart) bool { timer.Reset(timeout) + // When the idle timeout fired and cancelled streamCtx, the + // provider will yield a StreamPartTypeError carrying + // context.Canceled. Replace it with a descriptive, + // non-context error so the retry logic treats it as a + // retryable failure rather than a user-initiated abort. + if part.Type == StreamPartTypeError && timedOut.Load() { + part.Error = fmt.Errorf("%w: no data received for %s", errStreamIdleTimeout, timeout) + } if !yield(part) { stopped = true return false @@ -1665,7 +1673,7 @@ func withIdleTimeout(stream StreamResponse, timeout time.Duration, cancelFn cont if timedOut.Load() && !stopped { yield(StreamPart{ Type: StreamPartTypeError, - Error: fmt.Errorf("stream idle timeout exceeded (%s)", timeout), + Error: fmt.Errorf("%w: no data received for %s", errStreamIdleTimeout, timeout), }) } } diff --git a/agent_stream_test.go b/agent_stream_test.go index 6d9582677..0bd821535 100644 --- a/agent_stream_test.go +++ b/agent_stream_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "sync/atomic" "testing" "time" @@ -599,12 +600,15 @@ func TestStreamingAgentSources(t *testing.T) { } // TestStreamingAgentIdleTimeout verifies that a hanging stream is cancelled -// after the idle timeout fires. +// after the idle timeout fires and that retries are attempted. 
func TestStreamingAgentIdleTimeout(t *testing.T) { t.Parallel() + var attempts atomic.Int32 + mockModel := &mockLanguageModel{ streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + attempts.Add(1) return func(yield func(StreamPart) bool) { if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { return @@ -631,8 +635,10 @@ func TestStreamingAgentIdleTimeout(t *testing.T) { elapsed := time.Since(start) require.Error(t, err) - require.Contains(t, err.Error(), "stream idle timeout exceeded") - require.Less(t, elapsed, 2*time.Second, "should not block for a long time") + require.ErrorIs(t, err, errStreamIdleTimeout) + require.Less(t, elapsed, 10*time.Second, "should not block for a long time") + // Default retry is 2 retries, so 3 total attempts (1 initial + 2 retries). + require.Equal(t, int32(3), attempts.Load(), "should retry on idle timeout") } // TestStreamingAgentIdleTimeoutResetsOnChunks verifies that the idle timer diff --git a/errors.go b/errors.go index a925d4c43..7123cfe39 100644 --- a/errors.go +++ b/errors.go @@ -114,6 +114,10 @@ func ErrorTitleForStatusCode(statusCode int) string { return strings.ToLower(http.StatusText(statusCode)) } +// errStreamIdleTimeout is a sentinel used to identify stream-idle-timeout +// errors so the retry logic can treat them as retryable. +var errStreamIdleTimeout = errors.New("stream idle timeout") + // NoObjectGeneratedError is returned when object generation fails // due to parsing errors, validation errors, or model failures. type NoObjectGeneratedError struct { diff --git a/retry.go b/retry.go index 6b2c2a412..9a3689b1a 100644 --- a/retry.go +++ b/retry.go @@ -51,9 +51,15 @@ func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Du return exponentialBackoffDelay } -// isAbortError checks if the error is a context cancellation error. 
-func isAbortError(err error) bool { - return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) +// isAbortError checks whether the caller's context has been cancelled +// (e.g. by the user) or the error is a deadline-exceeded. We inspect +// ctx.Err() rather than the error value for context.Canceled because a +// stream-idle-timeout cancels a *child* streamCtx, producing a +// context.Canceled error even though the parent ctx is still alive. +// Checking ctx.Err() avoids aborting retries on transient idle-timeout +// cancellations. +func isAbortError(ctx context.Context, err error) bool { + return ctx.Err() != nil || errors.Is(err, context.DeadlineExceeded) } // RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries @@ -94,7 +100,7 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti return result, nil } - if isAbortError(err) { + if isAbortError(ctx, err) { return zero, err // don't retry when the request was aborted } @@ -110,9 +116,10 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti } var providerErr *ProviderError - if errors.As(err, &providerErr) && providerErr.IsRetryable() && tryNumber <= options.MaxRetries { + isIdleTimeout := errors.Is(err, errStreamIdleTimeout) + if (errors.As(err, &providerErr) && providerErr.IsRetryable() || isIdleTimeout) && tryNumber <= options.MaxRetries { delay := getRetryDelayInMs(err, options.InitialDelayIn) - if options.OnRetry != nil { + if options.OnRetry != nil && providerErr != nil { options.OnRetry(providerErr, delay) } From 803c63dd31cd21bd3c7e8bf10ab2b015585e178a Mon Sep 17 00:00:00 2001 From: Amos Bird Date: Mon, 6 Apr 2026 21:51:49 +0800 Subject: [PATCH 4/4] fix: increase default retries and allow OnRetry for idle timeouts - Increase default MaxRetries from 2 to 5 for better resilience during long tool-call parameter streaming (e.g. 
Write tool with large content) - Call OnRetry callback for idle timeout errors too (providerErr may be nil in that case), so callers can clean up partial state on retry - Update idle timeout test to verify retry count with MaxRetries=2 --- agent_stream_test.go | 6 +++++- retry.go | 5 ++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/agent_stream_test.go b/agent_stream_test.go index 0bd821535..b973434f4 100644 --- a/agent_stream_test.go +++ b/agent_stream_test.go @@ -628,6 +628,7 @@ func TestStreamingAgentIdleTimeout(t *testing.T) { streamCall := AgentStreamCall{ Prompt: "Say hello", StreamIdleTimeout: 100 * time.Millisecond, + MaxRetries: ptrTo(2), } start := time.Now() @@ -636,8 +637,9 @@ func TestStreamingAgentIdleTimeout(t *testing.T) { require.Error(t, err) require.ErrorIs(t, err, errStreamIdleTimeout) + // 2 retries with 2s initial delay and 2x backoff (2+4 = 6s). require.Less(t, elapsed, 10*time.Second, "should not block for a long time") - // Default retry is 2 retries, so 3 total attempts (1 initial + 2 retries). + // 3 total attempts (1 initial + 2 retries). require.Equal(t, int32(3), attempts.Load(), "should retry on idle timeout") } @@ -768,3 +770,5 @@ func TestStreamingAgentIdleTimeoutNoYieldAfterStop(t *testing.T) { _, err := agent.Stream(ctx, streamCall) require.Error(t, err) } + +func ptrTo[T any](v T) *T { return &v } diff --git a/retry.go b/retry.go index 9a3689b1a..c9e49d04e 100644 --- a/retry.go +++ b/retry.go @@ -82,11 +82,10 @@ type RetryOptions struct { // OnRetryCallback defines a function that is called when a retry occurs. type OnRetryCallback = func(err *ProviderError, delay time.Duration) -// DefaultRetryOptions returns the default retry options. // DefaultRetryOptions returns the default retry options. 
func DefaultRetryOptions() RetryOptions { return RetryOptions{ - MaxRetries: 2, + MaxRetries: 5, InitialDelayIn: 2000 * time.Millisecond, BackoffFactor: 2.0, } @@ -119,7 +118,7 @@ func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], opti isIdleTimeout := errors.Is(err, errStreamIdleTimeout) if (errors.As(err, &providerErr) && providerErr.IsRetryable() || isIdleTimeout) && tryNumber <= options.MaxRetries { delay := getRetryDelayInMs(err, options.InitialDelayIn) - if options.OnRetry != nil && providerErr != nil { + if options.OnRetry != nil { options.OnRetry(providerErr, delay) }