@@ -456,21 +456,64 @@ def _group_tool_calls_by_concurrency(prepared_calls: List[Dict[str, Any]]) -> Li
456456
457457
458458async def _execute_tools_sequentially (
459- items : List [Dict [str , Any ]], tool_results : List [UserMessage ]
459+ items : List [Dict [str , Any ]],
460+ tool_results : List [UserMessage ],
461+ abort_signal : Optional [asyncio .Event ] = None ,
460462) -> AsyncGenerator [Union [UserMessage , ProgressMessage ], None ]:
461463 """Run tool generators one by one."""
462464 for item in items :
463465 gen = item .get ("generator" )
464466 if not gen :
465467 continue
466- async for message in gen :
468+ if abort_signal is None :
469+ async for message in gen :
470+ if isinstance (message , UserMessage ):
471+ tool_results .append (message )
472+ yield message
473+ continue
474+
475+ while True :
476+ if abort_signal .is_set ():
477+ await gen .aclose ()
478+ return
479+
480+ next_item = asyncio .create_task (gen .__anext__ ())
481+ abort_waiter = asyncio .create_task (abort_signal .wait ())
482+ done , pending = await asyncio .wait (
483+ {next_item , abort_waiter },
484+ return_when = asyncio .FIRST_COMPLETED ,
485+ )
486+
487+ for pending_task in pending :
488+ pending_task .cancel ()
489+ try :
490+ await pending_task
491+ except asyncio .CancelledError :
492+ pass
493+
494+ if abort_waiter in done and abort_signal .is_set ():
495+ next_item .cancel ()
496+ try :
497+ await next_item
498+ except (asyncio .CancelledError , StopAsyncIteration ):
499+ pass
500+ await gen .aclose ()
501+ return
502+
503+ try :
504+ message = next_item .result ()
505+ except StopAsyncIteration :
506+ break
507+
467508 if isinstance (message , UserMessage ):
468509 tool_results .append (message )
469510 yield message
470511
471512
472513async def _execute_tools_in_parallel (
473- items : List [Dict [str , Any ]], tool_results : List [UserMessage ]
514+ items : List [Dict [str , Any ]],
515+ tool_results : List [UserMessage ],
516+ abort_signal : Optional [asyncio .Event ] = None ,
474517) -> AsyncGenerator [Union [UserMessage , ProgressMessage ], None ]:
475518 """Run tool generators concurrently."""
476519 logger .debug ("[query] _execute_tools_in_parallel ENTER: %d items" , len (items ))
@@ -482,42 +525,65 @@ async def _execute_tools_in_parallel(
482525 len (generators ),
483526 tool_names ,
484527 )
485- async for message in _run_concurrent_tool_uses (generators , tool_names , tool_results ):
528+ async for message in _run_concurrent_tool_uses (
529+ generators ,
530+ tool_names ,
531+ tool_results ,
532+ abort_signal = abort_signal ,
533+ ):
486534 yield message
487535 logger .debug ("[query] _execute_tools_in_parallel DONE" )
488536
489537
490538async def _run_tools_concurrently (
491- prepared_calls : List [Dict [str , Any ]], tool_results : List [UserMessage ]
539+ prepared_calls : List [Dict [str , Any ]],
540+ tool_results : List [UserMessage ],
541+ abort_signal : Optional [asyncio .Event ] = None ,
492542) -> AsyncGenerator [Union [UserMessage , ProgressMessage ], None ]:
493543 """Run tools grouped by concurrency safety (parallel for safe groups)."""
494544 for group in _group_tool_calls_by_concurrency (prepared_calls ):
495545 if group ["is_concurrency_safe" ]:
496546 logger .debug (
497547 f"[query] Executing { len (group ['items' ])} concurrency-safe tool(s) in parallel"
498548 )
499- async for message in _execute_tools_in_parallel (group ["items" ], tool_results ):
549+ async for message in _execute_tools_in_parallel (
550+ group ["items" ],
551+ tool_results ,
552+ abort_signal = abort_signal ,
553+ ):
500554 yield message
501555 else :
502556 logger .debug (
503557 f"[query] Executing { len (group ['items' ])} tool(s) sequentially (not concurrency safe)"
504558 )
505- async for message in _run_tools_serially (group ["items" ], tool_results ):
559+ async for message in _run_tools_serially (
560+ group ["items" ],
561+ tool_results ,
562+ abort_signal = abort_signal ,
563+ ):
506564 yield message
507565
508566
509567async def _run_tools_serially (
510- prepared_calls : List [Dict [str , Any ]], tool_results : List [UserMessage ]
568+ prepared_calls : List [Dict [str , Any ]],
569+ tool_results : List [UserMessage ],
570+ abort_signal : Optional [asyncio .Event ] = None ,
511571) -> AsyncGenerator [Union [UserMessage , ProgressMessage ], None ]:
512572 """Run all tools sequentially (helper for clarity)."""
513- async for message in _execute_tools_sequentially (prepared_calls , tool_results ):
573+ async for message in _execute_tools_sequentially (
574+ prepared_calls ,
575+ tool_results ,
576+ abort_signal = abort_signal ,
577+ ):
514578 yield message
515579
516580
517581async def _run_concurrent_tool_uses (
518582 generators : List [AsyncGenerator [Union [UserMessage , ProgressMessage ], None ]],
519583 tool_names : List [str ],
520584 tool_results : List [UserMessage ],
585+ * ,
586+ abort_signal : Optional [asyncio .Event ] = None ,
521587) -> AsyncGenerator [Union [UserMessage , ProgressMessage ], None ]:
522588 """Drain multiple tool generators concurrently and stream outputs with overall timeout."""
523589 overall_timeout_sec = _resolve_concurrent_timeout_sec (tool_names )
@@ -601,8 +667,41 @@ async def _consume(
601667 logger .debug (
602668 "[query] _run_concurrent_tool_uses: waiting for queue.get(), active=%d" , active
603669 )
670+ abort_waiter : Optional [asyncio .Task [bool ]] = None
671+ queue_waiter : Optional [asyncio .Task [Optional [Union [UserMessage , ProgressMessage ]]]] = None
604672 try :
605- message = await _queue_get_with_timeout (queue , overall_timeout_sec )
673+ queue_waiter = asyncio .create_task (
674+ _queue_get_with_timeout (queue , overall_timeout_sec )
675+ )
676+ if abort_signal is not None :
677+ abort_waiter = asyncio .create_task (abort_signal .wait ())
678+ done , pending = await asyncio .wait (
679+ {queue_waiter , abort_waiter },
680+ return_when = asyncio .FIRST_COMPLETED ,
681+ )
682+ for pending_task in pending :
683+ pending_task .cancel ()
684+ try :
685+ await pending_task
686+ except asyncio .CancelledError :
687+ pass
688+
689+ if abort_waiter in done and abort_signal .is_set ():
690+ logger .info (
691+ "[query] Abort signal set; cancelling %d concurrent tool task(s)" ,
692+ len (tasks ),
693+ )
694+ for task in tasks :
695+ if not task .done ():
696+ task .cancel ()
697+ break
698+
699+ if queue_waiter not in done :
700+ # Defensive: should not happen, but keep behavior deterministic.
701+ continue
702+ message = queue_waiter .result ()
703+ else :
704+ message = await queue_waiter
606705 except asyncio .TimeoutError :
607706 logger .error (
608707 "[query] Concurrent tool execution timed out waiting for messages"
@@ -612,6 +711,19 @@ async def _consume(
612711 if not task .done ():
613712 task .cancel ()
614713 raise
714+ finally :
715+ if abort_waiter is not None and not abort_waiter .done ():
716+ abort_waiter .cancel ()
717+ try :
718+ await abort_waiter
719+ except asyncio .CancelledError :
720+ pass
721+ if queue_waiter is not None and not queue_waiter .done ():
722+ queue_waiter .cancel ()
723+ try :
724+ await queue_waiter
725+ except asyncio .CancelledError :
726+ pass
615727
616728 logger .debug (
617729 "[query] _run_concurrent_tool_uses: got message type=%s, active=%d" ,
0 commit comments