diff --git a/providers/openai/language_model.go b/providers/openai/language_model.go index 231a77213..f614a460a 100644 --- a/providers/openai/language_model.go +++ b/providers/openai/language_model.go @@ -349,8 +349,7 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S if choice.FinishReason != "" { finishReason = choice.FinishReason } - switch { - case choice.Delta.Content != "": + if choice.Delta.Content != "" { if !isActiveText { isActiveText = true if !yield(fantasy.StreamPart{ @@ -367,7 +366,8 @@ func (o languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S }) { return } - case len(choice.Delta.ToolCalls) > 0: + } + if len(choice.Delta.ToolCalls) > 0 { if isActiveText { isActiveText = false if !yield(fantasy.StreamPart{ diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index 17d4af015..86aa14f0f 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -2288,6 +2288,16 @@ func (sms *streamingMockServer) prepareToolStreamResponse() { sms.chunks = chunks } +func (sms *streamingMockServer) prepareMixedContentToolStreamResponse() { + chunks := []string{ + `data: {"id":"chatcmpl-mixed","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":"thinking before tool","tool_calls":[{"index":0,"id":"call_mixed_content","type":"function","function":{"name":"test-tool","arguments":"{\"value\":\"mixed\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n", + `data: {"id":"chatcmpl-mixed","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n", + `data: {"id":"chatcmpl-mixed","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n", + "data: [DONE]\n\n", + } + sms.chunks = chunks +} + func (sms *streamingMockServer) prepareErrorStreamResponse() { chunks := []string{ `data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n", @@ -2566,6 +2576,74 @@ func TestDoStream(t *testing.T) { require.Equal(t, `{"value":"Sparkle Day"}`, fullInput.String()) }) + t.Run("should handle content and tool call deltas in the same chunk", func(t *testing.T) { + t.Parallel() + + server := newStreamingMockServer() + defer server.close() + + server.prepareMixedContentToolStreamResponse() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), + ) + require.NoError(t, err) + model, _ := provider.LanguageModel(t.Context(), "gpt-3.5-turbo") + + stream, err := model.Stream(context.Background(), fantasy.Call{ + Prompt: testPrompt, + Tools: []fantasy.Tool{ + fantasy.FunctionTool{ + Name: "test-tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{ + "type": "string", + }, + }, + "required": []string{"value"}, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#", + }, + }, + }, + }) + require.NoError(t, err) + + parts, err := collectStreamParts(stream) + require.NoError(t, err) + + var textDeltas []string + toolInputStart, toolInputEnd, toolCall := -1, -1, -1 + for i, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeTextDelta: + textDeltas = append(textDeltas, part.Delta) + case fantasy.StreamPartTypeToolInputStart: + toolInputStart = i + require.Equal(t, "call_mixed_content", part.ID) + require.Equal(t, "test-tool", part.ToolCallName) + case fantasy.StreamPartTypeToolInputEnd: + toolInputEnd = i + require.Equal(t, "call_mixed_content", part.ID) + case fantasy.StreamPartTypeToolCall: + toolCall = i + require.Equal(t, "call_mixed_content", part.ID) + require.Equal(t, "test-tool", part.ToolCallName) + require.Equal(t, `{"value":"mixed"}`, part.ToolCallInput) + } + } + + require.Equal(t, []string{"thinking before tool"}, textDeltas) + require.NotEqual(t, -1, toolInputStart, "expected ToolInputStart part") + require.NotEqual(t, -1, toolInputEnd, "expected ToolInputEnd part") + require.NotEqual(t, -1, toolCall, "expected ToolCall part") + require.Less(t, toolInputStart, toolInputEnd) + require.Less(t, toolInputEnd, toolCall) + }) + t.Run("should handle tool calls with empty arguments", func(t *testing.T) { t.Parallel()