@@ -226,6 +226,123 @@ def _tool_result_ids(msg: Dict[str, Any]) -> set[str]:
226226 return sanitized
227227
228228
229+ def sanitize_openai_tool_history (normalized_messages : List [Dict [str , Any ]]) -> List [Dict [str , Any ]]:
230+ """Normalize OpenAI chat-completions tool-call history.
231+
232+ Enforces strict pairing for OpenAI-compatible message sequences:
233+ 1. Drop assistant tool_calls that have no later matching role=tool response.
234+ 2. Drop role=tool messages that do not correspond to an earlier assistant tool_call.
235+ 3. If an assistant message mixes paired and unpaired tool_calls, keep only the paired subset.
236+ 4. Fold matching role=tool messages to immediately follow the assistant tool_call turn.
237+ """
238+ tool_response_indices : Dict [str , List [int ]] = {}
239+ for idx , message in enumerate (normalized_messages ):
240+ if message .get ("role" ) != "tool" :
241+ continue
242+ tool_call_id = str (message .get ("tool_call_id" ) or "" ).strip ()
243+ if tool_call_id :
244+ tool_response_indices .setdefault (tool_call_id , []).append (idx )
245+
246+ sanitized : List [Dict [str , Any ]] = []
247+ consumed_tool_indices : set [int ] = set ()
248+ i = 0
249+
250+ while i < len (normalized_messages ):
251+ message = normalized_messages [i ]
252+ role = message .get ("role" )
253+
254+ if role == "tool" :
255+ tool_call_id = str (message .get ("tool_call_id" ) or "" ).strip ()
256+ if i in consumed_tool_indices :
257+ i += 1
258+ continue
259+ logger .debug (
260+ "[provider_clients] Dropped orphan OpenAI tool response" ,
261+ extra = {"message_index" : i , "tool_call_id" : tool_call_id },
262+ )
263+ i += 1
264+ continue
265+
266+ if role != "assistant" :
267+ sanitized .append (message )
268+ i += 1
269+ continue
270+
271+ tool_calls = message .get ("tool_calls" )
272+ if not isinstance (tool_calls , list ) or not tool_calls :
273+ sanitized .append (message )
274+ i += 1
275+ continue
276+
277+ paired_tool_calls : List [Dict [str , Any ]] = []
278+ paired_ids : List [str ] = []
279+ for tool_call in tool_calls :
280+ if not isinstance (tool_call , dict ):
281+ continue
282+ tool_call_id = str (tool_call .get ("id" ) or "" ).strip ()
283+ if not tool_call_id :
284+ continue
285+ response_positions = tool_response_indices .get (tool_call_id , [])
286+ if any (response_idx > i and response_idx not in consumed_tool_indices for response_idx in response_positions ):
287+ paired_tool_calls .append (tool_call )
288+ paired_ids .append (tool_call_id )
289+
290+ if not paired_tool_calls :
291+ logger .debug (
292+ "[provider_clients] Dropped OpenAI assistant message with unpaired tool_calls" ,
293+ extra = {"message_index" : i },
294+ )
295+ i += 1
296+ continue
297+
298+ if len (paired_tool_calls ) != len (tool_calls ):
299+ logger .debug (
300+ "[provider_clients] Sanitized OpenAI assistant tool_calls to paired subset" ,
301+ extra = {
302+ "message_index" : i ,
303+ "before_count" : len (tool_calls ),
304+ "after_count" : len (paired_tool_calls ),
305+ },
306+ )
307+
308+ sanitized .append ({** message , "tool_calls" : paired_tool_calls })
309+
310+ expected_ids = set (paired_ids )
311+ seen_ids : set [str ] = set ()
312+ deferred_messages : List [Dict [str , Any ]] = []
313+ j = i + 1
314+ while j < len (normalized_messages ):
315+ next_message = normalized_messages [j ]
316+ next_role = next_message .get ("role" )
317+ if next_role == "assistant" :
318+ break
319+
320+ if next_role == "tool" :
321+ tool_call_id = str (next_message .get ("tool_call_id" ) or "" ).strip ()
322+ if tool_call_id in expected_ids and tool_call_id not in seen_ids :
323+ sanitized .append (next_message )
324+ consumed_tool_indices .add (j )
325+ seen_ids .add (tool_call_id )
326+ else :
327+ logger .debug (
328+ "[provider_clients] Dropped orphan or duplicate OpenAI tool response" ,
329+ extra = {"message_index" : j , "tool_call_id" : tool_call_id },
330+ )
331+ if expected_ids .issubset (seen_ids ):
332+ j += 1
333+ break
334+ j += 1
335+ continue
336+
337+ deferred_messages .append (next_message )
338+ j += 1
339+
340+ sanitized .extend (deferred_messages )
341+ i = j
342+
343+ return sanitized
344+
345+
229346def _retry_delay_seconds (attempt : int , base_delay : float = 0.5 , max_delay : float = 32.0 ) -> float :
230347 """Calculate exponential backoff with jitter."""
231348 capped_base : float = float (min (base_delay * (2 ** max (0 , attempt - 1 )), max_delay ))
0 commit comments