diff --git a/internal/adapter/claude/handler_stream_test.go b/internal/adapter/claude/handler_stream_test.go index 3d574fe..73a4c48 100644 --- a/internal/adapter/claude/handler_stream_test.go +++ b/internal/adapter/claude/handler_stream_test.go @@ -169,6 +169,17 @@ func TestHandleClaudeStreamRealtimeToolSafety(t *testing.T) { if !foundToolUse { t.Fatalf("expected tool_use block in stream, body=%s", rec.Body.String()) } + foundInputDelta := false + for _, f := range findClaudeFrames(frames, "content_block_delta") { + delta, _ := f.Payload["delta"].(map[string]any) + if delta["type"] == "input_json_delta" && strings.Contains(asString(delta["partial_json"]), `"q":"go"`) { + foundInputDelta = true + break + } + } + if !foundInputDelta { + t.Fatalf("expected input_json_delta with tool arguments, body=%s", rec.Body.String()) + } foundToolUseStop := false for _, f := range findClaudeFrames(frames, "message_delta") { diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go index 23520c0..488bdf6 100644 --- a/internal/adapter/claude/standard_request.go +++ b/internal/adapter/claude/standard_request.go @@ -38,6 +38,9 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma } finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"])) toolNames := extractClaudeToolNames(toolsRequested) + if len(toolNames) == 0 && len(toolsRequested) > 0 { + toolNames = []string{"__any_tool__"} + } return claudeNormalizedRequest{ Standard: util.StandardRequest{ diff --git a/internal/adapter/claude/stream_runtime_core.go b/internal/adapter/claude/stream_runtime_core.go index fead90a..a3dd649 100644 --- a/internal/adapter/claude/stream_runtime_core.go +++ b/internal/adapter/claude/stream_runtime_core.go @@ -8,7 +8,6 @@ import ( "ds2api/internal/sse" streamengine "ds2api/internal/stream" - "ds2api/internal/util" ) type claudeStreamRuntime struct { @@ -120,15 +119,6 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse if hasUnclosedCodeFence(s.text.String()) { continue } - detected := util.ParseToolCalls(s.text.String(), s.toolNames) - if len(detected) > 0 { - s.finalize("tool_use") - return streamengine.ParsedDecision{ - ContentSeen: true, - Stop: true, - StopReason: streamengine.StopReason("tool_use_detected"), - } - } continue } s.closeThinkingBlock() diff --git a/internal/adapter/claude/stream_runtime_finalize.go b/internal/adapter/claude/stream_runtime_finalize.go index 18a9e2d..493e491 100644 --- a/internal/adapter/claude/stream_runtime_finalize.go +++ b/internal/adapter/claude/stream_runtime_finalize.go @@ -1,6 +1,7 @@ package claude import ( + "encoding/json" "fmt" "time" @@ -53,6 +54,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { stopReason = "tool_use" for i, tc := range detected { idx := s.nextBlockIndex + i + inputJSON, _ := json.Marshal(tc.Input) s.send("content_block_start", map[string]any{ "type": "content_block_start", "index": idx, @@ -60,7 +62,15 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { "type": "tool_use", "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), "name": tc.Name, - "input": tc.Input, + "input": map[string]any{}, + }, + }) + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": idx, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": string(inputJSON), }, }) s.send("content_block_stop", map[string]any{ diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go index 4c38454..76f16fd 100644 --- a/internal/adapter/openai/handler_toolcall_format.go +++ b/internal/adapter/openai/handler_toolcall_format.go @@ -111,28 +111,21 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNam if len(deltas) == 0 { return nil } - allowed := namesToSet(allowedNames) - if len(allowed) == 0 { - for _, d := range deltas { - if d.Name != "" { - seenNames[d.Index] = "__blocked__" - } - } - return nil - } out := make([]toolCallDelta, 0, len(deltas)) for _, d := range deltas { if d.Name != "" { - if _, ok := allowed[d.Name]; !ok { - seenNames[d.Index] = "__blocked__" - continue + if seenNames != nil { + seenNames[d.Index] = d.Name } - seenNames[d.Index] = d.Name + out = append(out, d) + continue + } + if seenNames == nil { out = append(out, d) continue } name := strings.TrimSpace(seenNames[d.Index]) - if name == "" || name == "__blocked__" { + if name == "" { continue } out = append(out, d) diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index d3b849a..41b4c9f 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -182,7 +182,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) { } } -func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) { +func TestHandleNonStreamUnknownToolIntercepted(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`, @@ -198,16 +198,13 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) { out := decodeJSONBody(t, rec.Body.String()) choices, _ := out["choices"].([]any) choice, _ := choices[0].(map[string]any) - if choice["finish_reason"] != "stop" { - t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) + if choice["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"]) } msg, _ := choice["message"].(map[string]any) - if _, ok := msg["tool_calls"]; ok { - t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"]) - } - content, _ := msg["content"].(string) - if !strings.Contains(content, `"tool_calls"`) { - t.Fatalf("expected unknown tool json to pass through as text, got %#v", content) + toolCalls, _ := msg["tool_calls"].([]any) + if len(toolCalls) != 1 { + t.Fatalf("expected tool_calls for unknown schema name, got %#v", msg["tool_calls"]) } } @@ -413,7 +410,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing. } } -func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) { +func TestHandleStreamUnknownToolEmitsToolCall(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`, @@ -428,18 +425,18 @@ func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if streamHasToolCallsDelta(frames) { - t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String()) + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta for unknown schema name, body=%s", rec.Body.String()) } if streamHasRawToolJSONContent(frames) { t.Fatalf("did not expect raw tool_calls json leak for unknown schema name: %s", rec.Body.String()) } - if streamFinishReason(frames) != "stop" { - t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String()) + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) } } -func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) { +func TestHandleStreamUnknownToolNoArgsEmitsToolCall(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\"}]}"}`, @@ -454,14 +451,14 @@ func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if streamHasToolCallsDelta(frames) { - t.Fatalf("did not expect tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String()) + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String()) } if streamHasRawToolJSONContent(frames) { t.Fatalf("did not expect raw tool_calls json leak for unknown schema name (no args): %s", rec.Body.String()) } - if streamFinishReason(frames) != "stop" { - t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String()) + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) } } diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index 02d1f4b..fb11ca1 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -354,7 +354,7 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *te } } -func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { +func TestHandleResponsesStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) rec := httptest.NewRecorder() @@ -376,8 +376,8 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "") body := rec.Body.String() - if strings.Contains(body, "event: response.function_call_arguments.done") { - t.Fatalf("did not expect function_call events for tool_choice=none, body=%s", body) + if !strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("expected function_call events for tool_choice=none, body=%s", body) } } @@ -518,7 +518,7 @@ func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) { } } -func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) { +func TestHandleResponsesStreamAllowsUnknownToolName(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) rec := httptest.NewRecorder() @@ -539,8 +539,8 @@ func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) { h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") body := rec.Body.String() - if strings.Contains(body, "event: response.function_call_arguments.done") { - t.Fatalf("did not expect function_call events for unknown tool, body=%s", body) + if !strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("expected function_call events for unknown tool, body=%s", body) } } @@ -597,7 +597,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t } } -func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { +func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) { h := &Handler{} rec := httptest.NewRecorder() resp := &http.Response{ @@ -611,16 +611,20 @@ func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "") if rec.Code != http.StatusOK { - t.Fatalf("expected 200 for tool_choice=none passthrough text, got %d body=%s", rec.Code, rec.Body.String()) + t.Fatalf("expected 200 for tool_choice=none handling, got %d body=%s", rec.Code, rec.Body.String()) } out := decodeJSONBody(t, rec.Body.String()) output, _ := out["output"].([]any) + foundFunctionCall := false for _, item := range output { m, _ := item.(map[string]any) if m != nil && m["type"] == "function_call" { - t.Fatalf("did not expect function_call output item for tool_choice=none, got %#v", output) + foundFunctionCall = true } } + if !foundFunctionCall { + t.Fatalf("expected function_call output item for tool_choice=none, got %#v", output) + } } func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) { diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go index 1ba957c..af382cb 100644 --- a/internal/adapter/openai/standard_request.go +++ b/internal/adapter/openai/standard_request.go @@ -25,6 +25,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID } toolPolicy := util.DefaultToolChoicePolicy() finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy) + toolNames = ensureToolDetectionEnabled(toolNames, req["tools"]) passThrough := collectOpenAIChatPassThrough(req) return util.StandardRequest{ @@ -74,10 +75,8 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra return util.StandardRequest{}, err } finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy) - if toolPolicy.IsNone() { - toolNames = nil - toolPolicy.Allowed = nil - } else { + toolNames = ensureToolDetectionEnabled(toolNames, req["tools"]) + if !toolPolicy.IsNone() { toolPolicy.Allowed = namesToSet(toolNames) } passThrough := collectOpenAIChatPassThrough(req) @@ -98,6 +97,20 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra }, nil } +func ensureToolDetectionEnabled(toolNames []string, toolsRaw any) []string { + if len(toolNames) > 0 { + return toolNames + } + tools, _ := toolsRaw.([]any) + if len(tools) == 0 { + return toolNames + } + // Keep stream sieve/tool buffering enabled even when client tool schemas + // are malformed or lack explicit names; parsed tool payload names are no + // longer filtered by this list. + return []string{"__any_tool__"} +} + func collectOpenAIChatPassThrough(req map[string]any) map[string]any { out := map[string]any{} for _, k := range []string{ diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go index e8d1225..45a3976 100644 --- a/internal/adapter/openai/standard_request_test.go +++ b/internal/adapter/openai/standard_request_test.go @@ -152,7 +152,7 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testi } } -func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T) { +func TestNormalizeOpenAIResponsesRequestToolChoiceNoneKeepsToolDetectionEnabled(t *testing.T) { store := newEmptyStoreForNormalizeTest(t) req := map[string]any{ "model": "gpt-4o", @@ -174,7 +174,7 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T if n.ToolChoice.Mode != util.ToolChoiceNone { t.Fatalf("expected tool choice mode none, got %q", n.ToolChoice.Mode) } - if len(n.ToolNames) != 0 { - t.Fatalf("expected no tool names when tool_choice=none, got %#v", n.ToolNames) + if len(n.ToolNames) == 0 { + t.Fatalf("expected tool detection sentinel when tool_choice=none, got %#v", n.ToolNames) } } diff --git a/internal/js/chat-stream/toolcall_policy.js b/internal/js/chat-stream/toolcall_policy.js index e881bab..4a3bbed 100644 --- a/internal/js/chat-stream/toolcall_policy.js +++ b/internal/js/chat-stream/toolcall_policy.js @@ -8,7 +8,10 @@ const { function resolveToolcallPolicy(prepBody, payloadTools) { const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names); - const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools); + let toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools); + if (toolNames.length === 0 && Array.isArray(payloadTools) && payloadTools.length > 0) { + toolNames = ['__any_tool__']; + } const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match); const emitEarlyToolDeltas = featureMatchEnabled && boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high); return { @@ -76,17 +79,6 @@ function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenName return []; } const seen = seenNames instanceof Map ? seenNames : new Map(); - const allowed = new Set((allowedNames || []).filter((name) => asString(name) !== '')); - if (allowed.size === 0) { - for (const d of deltas) { - if (d && typeof d === 'object' && asString(d.name)) { - const index = Number.isInteger(d.index) ? d.index : 0; - seen.set(index, '__blocked__'); - } - } - return []; - } - const out = []; for (const d of deltas) { if (!d || typeof d !== 'object') { @@ -95,16 +87,12 @@ function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenName const index = Number.isInteger(d.index) ? d.index : 0; const name = asString(d.name); if (name) { - if (!allowed.has(name)) { - seen.set(index, '__blocked__'); - continue; - } seen.set(index, name); out.push(d); continue; } const existing = asString(seen.get(index)); - if (!existing || existing === '__blocked__') { + if (!existing) { continue; } out.push(d); diff --git a/internal/js/helpers/stream-tool-sieve/parse.js b/internal/js/helpers/stream-tool-sieve/parse.js index 586c45b..d1d3e89 100644 --- a/internal/js/helpers/stream-tool-sieve/parse.js +++ b/internal/js/helpers/stream-tool-sieve/parse.js @@ -140,63 +140,17 @@ function emptyParseResult() { } function filterToolCallsDetailed(parsed, toolNames) { - const sourceNames = Array.isArray(toolNames) ? toolNames : []; - const allowed = new Set(); - const allowedCanonical = new Map(); - for (const item of sourceNames) { - const name = toStringSafe(item); - if (!name) { - continue; - } - allowed.add(name); - const lower = name.toLowerCase(); - if (!allowedCanonical.has(lower)) { - allowedCanonical.set(lower, name); - } - } - - if (allowed.size === 0) { - const rejected = []; - const seen = new Set(); - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - if (seen.has(tc.name)) { - continue; - } - seen.add(tc.name); - rejected.push(tc.name); - } - return { calls: [], rejectedToolNames: rejected }; - } - const calls = []; - const rejected = []; - const seenRejected = new Set(); for (const tc of parsed) { if (!tc || !tc.name) { continue; } - let matchedName = ''; - if (allowed.has(tc.name)) { - matchedName = tc.name; - } else { - matchedName = resolveAllowedToolName(tc.name, allowed, allowedCanonical); - } - if (!matchedName) { - if (!seenRejected.has(tc.name)) { - seenRejected.add(tc.name); - rejected.push(tc.name); - } - continue; - } calls.push({ - name: matchedName, + name: tc.name, input: tc.input && typeof tc.input === 'object' && !Array.isArray(tc.input) ? tc.input : {}, }); } - return { calls, rejectedToolNames: rejected }; + return { calls, rejectedToolNames: [] }; } function resolveAllowedToolName(name, allowed, allowedCanonical) { diff --git a/internal/js/helpers/stream-tool-sieve/parse_payload.js b/internal/js/helpers/stream-tool-sieve/parse_payload.js index dad52ab..d3438a6 100644 --- a/internal/js/helpers/stream-tool-sieve/parse_payload.js +++ b/internal/js/helpers/stream-tool-sieve/parse_payload.js @@ -318,6 +318,9 @@ function parseToolCallItem(m) { hasInput = true; } } + if (!name && typeof m.function === 'string') { + name = toStringSafe(m.function); + } if (!hasInput) { for (const k of ['arguments', 'args', 'parameters', 'params']) { diff --git a/internal/util/toolcalls_parse.go b/internal/util/toolcalls_parse.go index 7aad445..b52329c 100644 --- a/internal/util/toolcalls_parse.go +++ b/internal/util/toolcalls_parse.go @@ -16,6 +16,7 @@ type ToolCallParseResult struct { RejectedByPolicy bool RejectedToolNames []string } + func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { return ParseToolCallsDetailed(text, availableToolNames).Calls } @@ -119,56 +120,17 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) } func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) { - allowed := map[string]struct{}{} - allowedCanonical := map[string]string{} - for _, name := range availableToolNames { - trimmed := strings.TrimSpace(name) - if trimmed == "" { - continue - } - allowed[trimmed] = struct{}{} - lower := strings.ToLower(trimmed) - if _, exists := allowedCanonical[lower]; !exists { - allowedCanonical[lower] = trimmed - } - } - if len(allowed) == 0 { - rejectedSet := map[string]struct{}{} - rejected := make([]string, 0, len(parsed)) - for _, tc := range parsed { - if tc.Name == "" { - continue - } - if _, ok := rejectedSet[tc.Name]; ok { - continue - } - rejectedSet[tc.Name] = struct{}{} - rejected = append(rejected, tc.Name) - } - return nil, rejected - } out := make([]ParsedToolCall, 0, len(parsed)) - rejectedSet := map[string]struct{}{} - rejected := make([]string, 0) for _, tc := range parsed { if tc.Name == "" { continue } - matchedName := resolveAllowedToolName(tc.Name, allowed, allowedCanonical) - if matchedName == "" { - if _, ok := rejectedSet[tc.Name]; !ok { - rejectedSet[tc.Name] = struct{}{} - rejected = append(rejected, tc.Name) - } - continue - } - tc.Name = matchedName if tc.Input == nil { tc.Input = map[string]any{} } out = append(out, tc) } - return out, rejected + return out, nil } func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string { @@ -269,6 +231,11 @@ func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) { } } } + if strings.TrimSpace(name) == "" { + if fnName, ok := m["function"].(string); ok { + name = fnName + } + } if !hasInput { for _, key := range []string{"arguments", "args", "parameters", "params"} { if v, ok := m[key]; ok { diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index 215d479..a761235 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -41,50 +41,64 @@ func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) { } } -func TestParseToolCallsRejectsUnknownToolName(t *testing.T) { +func TestParseToolCallsWithFunctionStringAndArgumentsObject(t *testing.T) { + text := `{"tool_calls":[{"function":"Write","arguments":{"file_path":"tmp/a.md","content":"ok"}}]}` + calls := ParseToolCalls(text, nil) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Name != "Write" { + t.Fatalf("unexpected tool name: %s", calls[0].Name) + } + if calls[0].Input["file_path"] != "tmp/a.md" { + t.Fatalf("unexpected args: %#v", calls[0].Input) + } +} + +func TestParseToolCallsKeepsUnknownToolName(t *testing.T) { text := `{"tool_calls":[{"name":"unknown","input":{}}]}` calls := ParseToolCalls(text, []string{"search"}) - if len(calls) != 0 { - t.Fatalf("expected unknown tool to be rejected, got %#v", calls) + if len(calls) != 1 || calls[0].Name != "unknown" { + t.Fatalf("expected unknown tool to be preserved, got %#v", calls) } } -func TestParseToolCallsAllowsCaseInsensitiveToolNameAndCanonicalizes(t *testing.T) { +func TestParseToolCallsKeepsOriginalToolNameCase(t *testing.T) { text := `{"tool_calls":[{"name":"Bash","input":{"command":"ls -al"}}]}` calls := ParseToolCalls(text, []string{"bash"}) if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "bash" { - t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + if calls[0].Name != "Bash" { + t.Fatalf("expected original tool name Bash, got %q", calls[0].Name) } } -func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) { +func TestParseToolCallsDetailedDoesNotRejectByPolicy(t *testing.T) { text := `{"tool_calls":[{"name":"unknown","input":{}}]}` res := ParseToolCallsDetailed(text, []string{"search"}) if !res.SawToolCallSyntax { t.Fatalf("expected SawToolCallSyntax=true, got %#v", res) } - if !res.RejectedByPolicy { - t.Fatalf("expected RejectedByPolicy=true, got %#v", res) + if res.RejectedByPolicy { + t.Fatalf("expected RejectedByPolicy=false, got %#v", res) } - if len(res.Calls) != 0 { - t.Fatalf("expected no calls after policy rejection, got %#v", res.Calls) + if len(res.Calls) != 1 || res.Calls[0].Name != "unknown" { + t.Fatalf("expected call to be preserved, got %#v", res.Calls) } } -func TestParseToolCallsDetailedRejectsWhenAllowListEmpty(t *testing.T) { +func TestParseToolCallsDetailedAllowsWhenAllowListEmpty(t *testing.T) { text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` res := ParseToolCallsDetailed(text, nil) if !res.SawToolCallSyntax { t.Fatalf("expected SawToolCallSyntax=true, got %#v", res) } - if !res.RejectedByPolicy { - t.Fatalf("expected RejectedByPolicy=true, got %#v", res) + if res.RejectedByPolicy { + t.Fatalf("expected RejectedByPolicy=false, got %#v", res) } - if len(res.Calls) != 0 { - t.Fatalf("expected no calls when allow-list is empty, got %#v", res.Calls) + if len(res.Calls) != 1 || res.Calls[0].Name != "search" { + t.Fatalf("expected calls when allow-list is empty, got %#v", res.Calls) } } @@ -132,8 +146,8 @@ func TestParseToolCallsAllowsQualifiedToolName(t *testing.T) { if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "search_web" { - t.Fatalf("expected canonical tool name search_web, got %q", calls[0].Name) + if calls[0].Name != "mcp.search_web" { + t.Fatalf("expected original tool name mcp.search_web, got %q", calls[0].Name) } } @@ -143,8 +157,8 @@ func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) { if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "read_file" { - t.Fatalf("expected canonical tool name read_file, got %q", calls[0].Name) + if calls[0].Name != "read-file" { + t.Fatalf("expected original tool name read-file, got %q", calls[0].Name) } } @@ -154,8 +168,8 @@ func TestParseToolCallsSupportsClaudeXMLToolCall(t *testing.T) { if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "bash" { - t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + if calls[0].Name != "Bash" { + t.Fatalf("expected original tool name Bash, got %q", calls[0].Name) } if calls[0].Input["command"] != "pwd" { t.Fatalf("expected command argument, got %#v", calls[0].Input) @@ -179,8 +193,8 @@ func TestParseToolCallsSupportsClaudeXMLJSONToolCall(t *testing.T) { if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "bash" { - t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + if calls[0].Name != "Bash" { + t.Fatalf("expected original tool name Bash, got %q", calls[0].Name) } if calls[0].Input["command"] != "pwd" { t.Fatalf("expected command argument, got %#v", calls[0].Input) @@ -193,8 +207,8 @@ func TestParseToolCallsSupportsFunctionCallTagStyle(t *testing.T) { if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "bash" { - t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + if calls[0].Name != "Bash" { + t.Fatalf("expected original tool name Bash, got %q", calls[0].Name) } if calls[0].Input["command"] != "ls -la" { t.Fatalf("expected command argument, got %#v", calls[0].Input) @@ -207,8 +221,8 @@ func TestParseToolCallsSupportsAntmlFunctionCallStyle(t *testing.T) { if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "bash" { - t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + if calls[0].Name != "Bash" { + t.Fatalf("expected original tool name Bash, got %q", calls[0].Name) } if calls[0].Input["command"] != "pwd" { t.Fatalf("expected command argument, got %#v", calls[0].Input) @@ -221,8 +235,8 @@ func TestParseToolCallsSupportsAntmlArgumentStyle(t *testing.T) { if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "bash" { - t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + if calls[0].Name != "Bash" { + t.Fatalf("expected original tool name Bash, got %q", calls[0].Name) } if calls[0].Input["command"] != "pwd" { t.Fatalf("expected command argument, got %#v", calls[0].Input) @@ -235,8 +249,8 @@ func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) { if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "bash" { - t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + if calls[0].Name != "Bash" { + t.Fatalf("expected original tool name Bash, got %q", calls[0].Name) } if calls[0].Input["command"] != "pwd" { t.Fatalf("expected command argument, got %#v", calls[0].Input) @@ -263,8 +277,8 @@ func TestParseToolCallsSupportsNestedToolTagStyle(t *testing.T) { if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "bash" { - t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + if calls[0].Name != "Bash" { + t.Fatalf("expected original tool name Bash, got %q", calls[0].Name) } if calls[0].Input["command"] != "pwd" { t.Fatalf("expected command argument, got %#v", calls[0].Input) @@ -277,8 +291,8 @@ func TestParseToolCallsSupportsAntmlFunctionAttributeWithParametersTag(t *testin if len(calls) != 1 { t.Fatalf("expected 1 call, got %#v", calls) } - if calls[0].Name != "bash" { - t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + if calls[0].Name != "Bash" { + t.Fatalf("expected original tool name Bash, got %q", calls[0].Name) } if calls[0].Input["command"] != "pwd" { t.Fatalf("expected command argument, got %#v", calls[0].Input) @@ -291,8 +305,8 @@ func TestParseToolCallsSupportsMultipleAntmlFunctionCalls(t *testing.T) { if len(calls) != 2 { t.Fatalf("expected 2 calls, got %#v", calls) } - if calls[0].Name != "bash" || calls[1].Name != "read" { - t.Fatalf("expected canonical names [bash read], got %#v", calls) + if calls[0].Name != "Bash" || calls[1].Name != "Read" { + t.Fatalf("expected original names [Bash Read], got %#v", calls) } } diff --git a/internal/util/util_edge_test.go b/internal/util/util_edge_test.go index 8113709..5d024a9 100644 --- a/internal/util/util_edge_test.go +++ b/internal/util/util_edge_test.go @@ -364,8 +364,8 @@ func TestFormatOpenAIStreamToolCalls(t *testing.T) { func TestParseToolCallsNoToolNames(t *testing.T) { text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` calls := ParseToolCalls(text, nil) - if len(calls) != 0 { - t.Fatalf("expected 0 call with nil tool names, got %d", len(calls)) + if len(calls) != 1 { + t.Fatalf("expected 1 call with nil tool names, got %d", len(calls)) } } diff --git a/tests/node/stream-tool-sieve.test.js b/tests/node/stream-tool-sieve.test.js index 23834ec..58881cd 100644 --- a/tests/node/stream-tool-sieve.test.js +++ b/tests/node/stream-tool-sieve.test.js @@ -55,33 +55,44 @@ test('parseToolCalls keeps non-object argument strings as _raw (Go parity)', () ]); }); -test('parseToolCalls drops unknown schema names when toolNames is provided', () => { +test('parseToolCalls supports function string + arguments object payload', () => { + const payload = JSON.stringify({ + tool_calls: [{ function: 'Write', arguments: { file_path: 'tmp/a.md', content: 'ok' } }], + }); + const calls = parseToolCalls(payload, []); + assert.equal(calls.length, 1); + assert.equal(calls[0].name, 'Write'); + assert.equal(calls[0].input.file_path, 'tmp/a.md'); +}); + +test('parseToolCalls keeps unknown schema names when toolNames is provided', () => { const payload = JSON.stringify({ tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }], }); const calls = parseToolCalls(payload, ['search']); - assert.equal(calls.length, 0); + assert.equal(calls.length, 1); + assert.equal(calls[0].name, 'not_in_schema'); }); -test('parseToolCalls matches tool name case-insensitively and canonicalizes', () => { +test('parseToolCalls keeps original tool name casing', () => { const payload = JSON.stringify({ tool_calls: [{ name: 'Read_File', input: { path: 'README.MD' } }], }); const calls = parseToolCalls(payload, ['read_file']); - assert.deepEqual(calls, [{ name: 'read_file', input: { path: 'README.MD' } }]); + assert.deepEqual(calls, [{ name: 'Read_File', input: { path: 'README.MD' } }]); }); -test('parseToolCalls rejects all names when toolNames is empty (Go strict parity)', () => { +test('parseToolCalls accepts all names when toolNames is empty', () => { const payload = JSON.stringify({ tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }], }); const calls = parseToolCalls(payload, []); - assert.equal(calls.length, 0); + assert.equal(calls.length, 1); const detailed = parseToolCallsDetailed(payload, []); assert.equal(detailed.sawToolCallSyntax, true); - assert.equal(detailed.rejectedByPolicy, true); - assert.deepEqual(detailed.rejectedToolNames, ['not_in_schema']); + assert.equal(detailed.rejectedByPolicy, false); + assert.deepEqual(detailed.rejectedToolNames, []); }); test('parseToolCalls ignores tool_call payloads that exist only inside fenced code blocks', () => { @@ -287,7 +298,7 @@ test('sieve preserves text spacing when TOOL_RESULT_HISTORY spans chunks', () => assert.equal(leakedText, 'Hello world'); }); -test('sieve intercepts rejected unknown tool payload (no args) without raw leak', () => { +test('sieve emits unknown tool payload (no args) as executable tool call', () => { const events = runSieve( ['{"tool_calls":[{"name":"not_in_schema"}]}', '后置正文G。'], ['read_file'], @@ -295,8 +306,7 @@ test('sieve intercepts rejected unknown tool payload (no args) without raw leak' const leakedText = collectText(events); const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0); const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0); - assert.equal(hasToolCall, false); - assert.equal(hasToolDelta, false); + assert.equal(hasToolCall || hasToolDelta, true); assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); assert.equal(leakedText.includes('后置正文G。'), true); });