Skip to content

Commit 7e46658

Browse files
committed
feat(providers): enhance OpenAI tool history sanitization
Add a sanitize_openai_tool_history function that enforces strict message pairing in OpenAI-compatible provider sequences:
- Drops unpaired assistant tool_calls and orphan tool responses
- Keeps only the paired subset when mixing paired/unpaired tool_calls
- Preserves valid tool chains while removing invalid sequences
- Includes comprehensive test coverage for various edge cases

Also fixes MCP schema variable-name collisions detected by mypy.

Generated with Ripperdoc
Co-Authored-By: Ripperdoc
1 parent 0310b20 commit 7e46658

5 files changed

Lines changed: 247 additions & 10 deletions

File tree

ripperdoc/core/providers/base.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,123 @@ def _tool_result_ids(msg: Dict[str, Any]) -> set[str]:
226226
return sanitized
227227

228228

229+
def sanitize_openai_tool_history(normalized_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Rewrite an OpenAI chat-completions history so tool traffic is strictly paired.

    The returned list satisfies the pairing rules OpenAI-compatible endpoints
    require:

    1. Assistant ``tool_calls`` with no later matching ``role=tool`` response
       are dropped.
    2. ``role=tool`` messages that do not answer an earlier assistant
       ``tool_call`` are dropped.
    3. An assistant turn mixing paired and unpaired ``tool_calls`` is trimmed
       to the paired subset.
    4. Matching ``role=tool`` messages are folded to sit immediately after the
       assistant turn that issued them; intervening messages are deferred
       until after the tool results.
    """

    def _stripped(raw: Any) -> str:
        # Tool-call ids may be None or padded; compare them in canonical form.
        return str(raw or "").strip()

    # Index every tool response by its tool_call_id so pairing checks are O(1).
    responses_by_id: Dict[str, List[int]] = {}
    for position, entry in enumerate(normalized_messages):
        if entry.get("role") != "tool":
            continue
        response_id = _stripped(entry.get("tool_call_id"))
        if response_id:
            responses_by_id.setdefault(response_id, []).append(position)

    sanitized: List[Dict[str, Any]] = []
    # Indices of tool messages already folded in behind their assistant turn.
    claimed_responses: set[int] = set()
    total = len(normalized_messages)
    cursor = 0

    while cursor < total:
        msg = normalized_messages[cursor]
        role = msg.get("role")

        if role == "tool":
            # Either this response was already folded in behind its assistant
            # turn, or it never had one; it is skipped here either way.
            if cursor not in claimed_responses:
                logger.debug(
                    "[provider_clients] Dropped orphan OpenAI tool response",
                    extra={"message_index": cursor, "tool_call_id": _stripped(msg.get("tool_call_id"))},
                )
            cursor += 1
            continue

        if role != "assistant":
            sanitized.append(msg)
            cursor += 1
            continue

        calls = msg.get("tool_calls")
        if not isinstance(calls, list) or not calls:
            # Plain assistant text: nothing to pair.
            sanitized.append(msg)
            cursor += 1
            continue

        # Keep only calls that still have an unclaimed response later on.
        kept_calls: List[Dict[str, Any]] = []
        kept_ids: List[str] = []
        for call in calls:
            if not isinstance(call, dict):
                continue
            call_id = _stripped(call.get("id"))
            if not call_id:
                continue
            later_positions = responses_by_id.get(call_id, [])
            if any(pos > cursor and pos not in claimed_responses for pos in later_positions):
                kept_calls.append(call)
                kept_ids.append(call_id)

        if not kept_calls:
            logger.debug(
                "[provider_clients] Dropped OpenAI assistant message with unpaired tool_calls",
                extra={"message_index": cursor},
            )
            cursor += 1
            continue

        if len(kept_calls) != len(calls):
            logger.debug(
                "[provider_clients] Sanitized OpenAI assistant tool_calls to paired subset",
                extra={
                    "message_index": cursor,
                    "before_count": len(calls),
                    "after_count": len(kept_calls),
                },
            )

        patched = dict(msg)
        patched["tool_calls"] = kept_calls
        sanitized.append(patched)

        # Scan forward to the next assistant turn, folding matching tool
        # responses directly behind this one and deferring everything else.
        expected = set(kept_ids)
        matched: set[str] = set()
        held_back: List[Dict[str, Any]] = []
        scan = cursor + 1
        while scan < total:
            candidate = normalized_messages[scan]
            candidate_role = candidate.get("role")
            if candidate_role == "assistant":
                break

            if candidate_role == "tool":
                candidate_id = _stripped(candidate.get("tool_call_id"))
                if candidate_id in expected and candidate_id not in matched:
                    sanitized.append(candidate)
                    claimed_responses.add(scan)
                    matched.add(candidate_id)
                else:
                    logger.debug(
                        "[provider_clients] Dropped orphan or duplicate OpenAI tool response",
                        extra={"message_index": scan, "tool_call_id": candidate_id},
                    )
                scan += 1
                # Stop scanning early once every expected response is placed.
                if expected <= matched:
                    break
                continue

            held_back.append(candidate)
            scan += 1

        sanitized.extend(held_back)
        cursor = scan

    return sanitized
344+
345+
229346
def _retry_delay_seconds(attempt: int, base_delay: float = 0.5, max_delay: float = 32.0) -> float:
230347
"""Calculate exponential backoff with jitter."""
231348
capped_base: float = float(min(base_delay * (2 ** max(0, attempt - 1)), max_delay))

ripperdoc/core/providers/openai_non_oauth_strategies.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ProviderResponse,
2525
call_with_timeout_and_retries,
2626
iter_with_timeout,
27+
sanitize_openai_tool_history,
2728
sanitize_tool_history,
2829
)
2930
from ripperdoc.core.providers.openai_responses import (
@@ -499,9 +500,12 @@ async def call(
499500
default_headers: Optional[Dict[str, str]] = None,
500501
) -> ProviderResponse:
501502
openai_tools = await build_openai_tool_schemas(tools)
503+
sanitized_messages = sanitize_openai_tool_history(
504+
sanitize_tool_history(list(normalized_messages))
505+
)
502506
openai_messages: List[Dict[str, object]] = [
503507
{"role": "system", "content": system_prompt}
504-
] + sanitize_tool_history(list(normalized_messages))
508+
] + sanitized_messages
505509

506510
logger.debug(
507511
"[openai_client] Preparing request",
@@ -716,7 +720,9 @@ async def call(
716720
) -> ProviderResponse:
717721
openai_tools = await build_openai_tool_schemas(tools)
718722
response_tools = convert_chat_function_tools_to_responses_tools(openai_tools)
719-
sanitized_messages = sanitize_tool_history(list(normalized_messages))
723+
sanitized_messages = sanitize_openai_tool_history(
724+
sanitize_tool_history(list(normalized_messages))
725+
)
720726
response_input = build_input_from_normalized_messages(
721727
cast(List[Dict[str, Any]], sanitized_messages),
722728
assistant_text_type="output_text",

ripperdoc/core/providers/openai_oauth_codex.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ProgressCallback,
2222
ProviderResponse,
2323
call_with_timeout_and_retries,
24+
sanitize_openai_tool_history,
2425
sanitize_tool_history,
2526
)
2627
from ripperdoc.core.providers.error_mapping import (
@@ -147,7 +148,9 @@ async def call_oauth_codex(
147148
)
148149

149150
openai_tools = await build_openai_tool_schemas(tools)
150-
sanitized_messages = sanitize_tool_history(list(normalized_messages))
151+
sanitized_messages = sanitize_openai_tool_history(
152+
sanitize_tool_history(list(normalized_messages))
153+
)
151154
response_input = _build_codex_responses_input(
152155
cast(List[Dict[str, Any]], sanitized_messages),
153156
assistant_text_type="output_text",

ripperdoc/utils/mcp.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ def _coerce_sdk_schema(value: Any) -> dict[str, Any]:
127127

128128
if hasattr(value, "model_json_schema") and callable(value.model_json_schema):
129129
try:
130-
schema = value.model_json_schema()
131-
if isinstance(schema, dict):
132-
return schema
130+
json_schema = value.model_json_schema()
131+
if isinstance(json_schema, dict):
132+
return json_schema
133133
except (TypeError, ValueError, AttributeError):
134134
pass
135135

@@ -153,10 +153,10 @@ def _coerce_sdk_schema(value: Any) -> dict[str, Any]:
153153
key_str = str(key)
154154
properties[key_str] = _coerce_sdk_schema(item)
155155
required.append(key_str)
156-
schema: dict[str, Any] = {"type": "object", "properties": properties}
156+
schema_dict: dict[str, Any] = {"type": "object", "properties": properties}
157157
if required:
158-
schema["required"] = required
159-
return schema
158+
schema_dict["required"] = required
159+
return schema_dict
160160

161161
if isinstance(value, (list, tuple)):
162162
items = _coerce_sdk_schema(value[0]) if len(value) == 1 else {}

tests/test_messages.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
is_hidden_meta_message,
8989
normalize_messages_for_api,
9090
)
91-
from ripperdoc.core.providers.base import sanitize_tool_history
91+
from ripperdoc.core.providers.base import sanitize_openai_tool_history, sanitize_tool_history
9292

9393

9494
def test_create_user_message():
@@ -763,6 +763,117 @@ def test_sanitize_tool_history_replays_real_session_parallel_git_tool_calls():
763763
]
764764

765765

766+
def test_sanitize_openai_tool_history_drops_unpaired_assistant_tool_calls():
    """An assistant tool_call with no later tool response is removed entirely."""
    unpaired_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "Read", "arguments": '{"path":"README.md"}'},
    }
    history = [{"role": "assistant", "content": None, "tool_calls": [unpaired_call]}]

    assert sanitize_openai_tool_history(history) == []
784+
785+
786+
def test_sanitize_openai_tool_history_drops_orphan_tool_messages():
    """A role=tool message without a preceding assistant tool_call is dropped."""
    user_turn = {"role": "user", "content": "hello"}
    orphan = {"role": "tool", "tool_call_id": "orphan_call", "content": "result"}

    result = sanitize_openai_tool_history([user_turn, orphan])

    assert result == [{"role": "user", "content": "hello"}]
795+
796+
797+
def test_sanitize_openai_tool_history_keeps_only_paired_assistant_tool_calls():
    """Mixing paired and unpaired tool_calls keeps only the paired subset."""
    read_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "Read", "arguments": '{"path":"README.md"}'},
    }
    glob_call = {
        "id": "call_2",
        "type": "function",
        "function": {"name": "Glob", "arguments": '{"pattern":"*.py"}'},
    }
    glob_result = {"role": "tool", "tool_call_id": "call_2", "content": "matched files"}
    history = [
        {
            "role": "assistant",
            "content": "running tools",
            "tool_calls": [read_call, glob_call],
        },
        glob_result,
    ]

    result = sanitize_openai_tool_history(history)

    assert len(result) == 2
    assert result[0]["role"] == "assistant"
    assert [call["id"] for call in result[0]["tool_calls"]] == ["call_2"]
    assert result[1] == {"role": "tool", "tool_call_id": "call_2", "content": "matched files"}
824+
825+
826+
def test_sanitize_openai_tool_history_preserves_valid_tool_chain():
    """A correctly paired assistant/tool sequence passes through unchanged."""
    read_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "Read", "arguments": '{"path":"README.md"}'},
    }
    history = [
        {"role": "user", "content": "read the file"},
        {"role": "assistant", "content": None, "tool_calls": [read_call]},
        {"role": "tool", "tool_call_id": "call_1", "content": "file contents"},
        {"role": "assistant", "content": "done"},
    ]

    assert sanitize_openai_tool_history(history) == history
847+
848+
849+
def test_sanitize_openai_tool_history_reorders_intervening_messages_after_tool_results():
    """Messages between a tool_call and its result are folded after the result."""
    read_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "Read", "arguments": '{"path":"README.md"}'},
    }
    history = [
        {"role": "assistant", "content": None, "tool_calls": [read_call]},
        {"role": "user", "content": "PreToolUse:Read hook additional context"},
        {"role": "tool", "tool_call_id": "call_1", "content": "file contents"},
        {"role": "assistant", "content": "done"},
    ]

    result = sanitize_openai_tool_history(history)

    assert result == [history[0], history[2], history[1], history[3]]
875+
876+
766877
def test_normalize_messages_with_reasoning_metadata():
767878
"""Ensure reasoning metadata is preserved for OpenAI-style messages."""
768879
assistant = create_assistant_message(

0 commit comments

Comments
 (0)