Skip to content

Commit 40c65d3

Browse files
committed
feat: support graceful abort and foreground-to-background subagent handoff
Add abort signal propagation through tool execution layers, allowing concurrent and sequential tool runs to be cancelled cleanly. Implement automatic handoff of interrupted foreground subagents to background tasks so they can continue running after user interruption. Fix UI event loop shutdown to cancel pending tasks before async generators, and defer background notification responses when the interactive prompt is active to avoid rendering races. Generated with Ripperdoc Co-Authored-By: Ripperdoc
1 parent ee5ac55 commit 40c65d3

File tree

9 files changed

+473
-16
lines changed

9 files changed

+473
-16
lines changed

ripperdoc/cli/ui/rich_ui/rendering.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def _summarize_subagent_progress_content(content: Any) -> str:
2727
return str(content)
2828

2929
message_payload = getattr(content, "message", None)
30+
metadata = getattr(message_payload, "metadata", None) or {}
31+
if metadata.get("hook_additional_context"):
32+
return ""
3033
body = getattr(message_payload, "content", None) if message_payload is not None else None
3134
if isinstance(body, str):
3235
return body

ripperdoc/cli/ui/rich_ui/session.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,22 @@ def _schedule_background_notification_response(
15191519
"""Trigger a follow-up assistant response for an idle task notification."""
15201520
if self._loop.is_closed():
15211521
return
1522+
# When the interactive prompt is active, avoid launching an automatic
1523+
# query that would race prompt_toolkit redraw and spinner rendering.
1524+
# Defer it into pending messages so it is consumed on the next user turn.
1525+
prompt_app = getattr(getattr(self, "_prompt_session", None), "app", None)
1526+
if (
1527+
not self._query_in_progress
1528+
and prompt_app is not None
1529+
and bool(getattr(prompt_app, "is_running", False))
1530+
and self.query_context is not None
1531+
):
1532+
self.query_context.pending_message_queue.enqueue_text(
1533+
agent_message,
1534+
metadata=metadata,
1535+
)
1536+
self._request_prompt_redraw()
1537+
return
15221538

15231539
async def _run() -> None:
15241540
try:
@@ -1974,7 +1990,13 @@ def _run_async(self, coro: Any) -> Any:
19741990
if threading.current_thread() is self._loop_thread:
19751991
raise RuntimeError("_run_async cannot be called from the UI event loop thread")
19761992
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
1977-
return future.result()
1993+
try:
1994+
return future.result()
1995+
except KeyboardInterrupt:
1996+
# If the caller interrupts while waiting, propagate cancellation to the
1997+
# coroutine so async generators are not left running into shutdown.
1998+
future.cancel()
1999+
raise
19782000

19792001
def run_async(self, coro: Any) -> Any:
19802002
"""Public wrapper for running coroutines on the UI event loop."""
@@ -2432,11 +2454,21 @@ def _shutdown_event_loop_thread(self) -> None:
24322454
if self._loop.is_closed():
24332455
return
24342456

2435-
async def _shutdown_asyncgens_only() -> None:
2436-
await asyncio.get_running_loop().shutdown_asyncgens()
2457+
async def _drain_tasks_and_shutdown_asyncgens() -> None:
2458+
loop = asyncio.get_running_loop()
2459+
current = asyncio.current_task(loop=loop)
2460+
pending = [
2461+
task for task in asyncio.all_tasks(loop)
2462+
if task is not current and not task.done()
2463+
]
2464+
for task in pending:
2465+
task.cancel()
2466+
if pending:
2467+
await asyncio.gather(*pending, return_exceptions=True)
2468+
await loop.shutdown_asyncgens()
24372469

24382470
try:
2439-
self._run_async(_shutdown_asyncgens_only())
2471+
self._run_async(_drain_tasks_and_shutdown_asyncgens())
24402472
except (RuntimeError, asyncio.CancelledError, concurrent.futures.TimeoutError):
24412473
pass
24422474

ripperdoc/core/query/loop.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1438,7 +1438,11 @@ async def _process_iteration_assistant_message(
14381438
return
14391439

14401440
if prepared.prepared_calls:
1441-
async for message in _run_tools_concurrently(prepared.prepared_calls, prepared.tool_results):
1441+
async for message in _run_tools_concurrently(
1442+
prepared.prepared_calls,
1443+
prepared.tool_results,
1444+
abort_signal=query_context.abort_controller,
1445+
):
14421446
yield message
14431447

14441448
_apply_skill_context_updates(prepared.tool_results, query_context, context)
@@ -1491,6 +1495,7 @@ async def _process_iteration_assistant_message(
14911495
async for message in _run_tools_concurrently(
14921496
auto_prepared.prepared_calls,
14931497
auto_prepared.tool_results,
1498+
abort_signal=query_context.abort_controller,
14941499
):
14951500
yield message
14961501

ripperdoc/core/query/tools.py

Lines changed: 122 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -456,21 +456,64 @@ def _group_tool_calls_by_concurrency(prepared_calls: List[Dict[str, Any]]) -> Li
456456

457457

458458
async def _execute_tools_sequentially(
459-
items: List[Dict[str, Any]], tool_results: List[UserMessage]
459+
items: List[Dict[str, Any]],
460+
tool_results: List[UserMessage],
461+
abort_signal: Optional[asyncio.Event] = None,
460462
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
461463
"""Run tool generators one by one."""
462464
for item in items:
463465
gen = item.get("generator")
464466
if not gen:
465467
continue
466-
async for message in gen:
468+
if abort_signal is None:
469+
async for message in gen:
470+
if isinstance(message, UserMessage):
471+
tool_results.append(message)
472+
yield message
473+
continue
474+
475+
while True:
476+
if abort_signal.is_set():
477+
await gen.aclose()
478+
return
479+
480+
next_item = asyncio.create_task(gen.__anext__())
481+
abort_waiter = asyncio.create_task(abort_signal.wait())
482+
done, pending = await asyncio.wait(
483+
{next_item, abort_waiter},
484+
return_when=asyncio.FIRST_COMPLETED,
485+
)
486+
487+
for pending_task in pending:
488+
pending_task.cancel()
489+
try:
490+
await pending_task
491+
except asyncio.CancelledError:
492+
pass
493+
494+
if abort_waiter in done and abort_signal.is_set():
495+
next_item.cancel()
496+
try:
497+
await next_item
498+
except (asyncio.CancelledError, StopAsyncIteration):
499+
pass
500+
await gen.aclose()
501+
return
502+
503+
try:
504+
message = next_item.result()
505+
except StopAsyncIteration:
506+
break
507+
467508
if isinstance(message, UserMessage):
468509
tool_results.append(message)
469510
yield message
470511

471512

472513
async def _execute_tools_in_parallel(
473-
items: List[Dict[str, Any]], tool_results: List[UserMessage]
514+
items: List[Dict[str, Any]],
515+
tool_results: List[UserMessage],
516+
abort_signal: Optional[asyncio.Event] = None,
474517
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
475518
"""Run tool generators concurrently."""
476519
logger.debug("[query] _execute_tools_in_parallel ENTER: %d items", len(items))
@@ -482,42 +525,65 @@ async def _execute_tools_in_parallel(
482525
len(generators),
483526
tool_names,
484527
)
485-
async for message in _run_concurrent_tool_uses(generators, tool_names, tool_results):
528+
async for message in _run_concurrent_tool_uses(
529+
generators,
530+
tool_names,
531+
tool_results,
532+
abort_signal=abort_signal,
533+
):
486534
yield message
487535
logger.debug("[query] _execute_tools_in_parallel DONE")
488536

489537

490538
async def _run_tools_concurrently(
491-
prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
539+
prepared_calls: List[Dict[str, Any]],
540+
tool_results: List[UserMessage],
541+
abort_signal: Optional[asyncio.Event] = None,
492542
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
493543
"""Run tools grouped by concurrency safety (parallel for safe groups)."""
494544
for group in _group_tool_calls_by_concurrency(prepared_calls):
495545
if group["is_concurrency_safe"]:
496546
logger.debug(
497547
f"[query] Executing {len(group['items'])} concurrency-safe tool(s) in parallel"
498548
)
499-
async for message in _execute_tools_in_parallel(group["items"], tool_results):
549+
async for message in _execute_tools_in_parallel(
550+
group["items"],
551+
tool_results,
552+
abort_signal=abort_signal,
553+
):
500554
yield message
501555
else:
502556
logger.debug(
503557
f"[query] Executing {len(group['items'])} tool(s) sequentially (not concurrency safe)"
504558
)
505-
async for message in _run_tools_serially(group["items"], tool_results):
559+
async for message in _run_tools_serially(
560+
group["items"],
561+
tool_results,
562+
abort_signal=abort_signal,
563+
):
506564
yield message
507565

508566

509567
async def _run_tools_serially(
510-
prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
568+
prepared_calls: List[Dict[str, Any]],
569+
tool_results: List[UserMessage],
570+
abort_signal: Optional[asyncio.Event] = None,
511571
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
512572
"""Run all tools sequentially (helper for clarity)."""
513-
async for message in _execute_tools_sequentially(prepared_calls, tool_results):
573+
async for message in _execute_tools_sequentially(
574+
prepared_calls,
575+
tool_results,
576+
abort_signal=abort_signal,
577+
):
514578
yield message
515579

516580

517581
async def _run_concurrent_tool_uses(
518582
generators: List[AsyncGenerator[Union[UserMessage, ProgressMessage], None]],
519583
tool_names: List[str],
520584
tool_results: List[UserMessage],
585+
*,
586+
abort_signal: Optional[asyncio.Event] = None,
521587
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
522588
"""Drain multiple tool generators concurrently and stream outputs with overall timeout."""
523589
overall_timeout_sec = _resolve_concurrent_timeout_sec(tool_names)
@@ -601,8 +667,41 @@ async def _consume(
601667
logger.debug(
602668
"[query] _run_concurrent_tool_uses: waiting for queue.get(), active=%d", active
603669
)
670+
abort_waiter: Optional[asyncio.Task[bool]] = None
671+
queue_waiter: Optional[asyncio.Task[Optional[Union[UserMessage, ProgressMessage]]]] = None
604672
try:
605-
message = await _queue_get_with_timeout(queue, overall_timeout_sec)
673+
queue_waiter = asyncio.create_task(
674+
_queue_get_with_timeout(queue, overall_timeout_sec)
675+
)
676+
if abort_signal is not None:
677+
abort_waiter = asyncio.create_task(abort_signal.wait())
678+
done, pending = await asyncio.wait(
679+
{queue_waiter, abort_waiter},
680+
return_when=asyncio.FIRST_COMPLETED,
681+
)
682+
for pending_task in pending:
683+
pending_task.cancel()
684+
try:
685+
await pending_task
686+
except asyncio.CancelledError:
687+
pass
688+
689+
if abort_waiter in done and abort_signal.is_set():
690+
logger.info(
691+
"[query] Abort signal set; cancelling %d concurrent tool task(s)",
692+
len(tasks),
693+
)
694+
for task in tasks:
695+
if not task.done():
696+
task.cancel()
697+
break
698+
699+
if queue_waiter not in done:
700+
# Defensive: should not happen, but keep behavior deterministic.
701+
continue
702+
message = queue_waiter.result()
703+
else:
704+
message = await queue_waiter
606705
except asyncio.TimeoutError:
607706
logger.error(
608707
"[query] Concurrent tool execution timed out waiting for messages"
@@ -612,6 +711,19 @@ async def _consume(
612711
if not task.done():
613712
task.cancel()
614713
raise
714+
finally:
715+
if abort_waiter is not None and not abort_waiter.done():
716+
abort_waiter.cancel()
717+
try:
718+
await abort_waiter
719+
except asyncio.CancelledError:
720+
pass
721+
if queue_waiter is not None and not queue_waiter.done():
722+
queue_waiter.cancel()
723+
try:
724+
await queue_waiter
725+
except asyncio.CancelledError:
726+
pass
615727

616728
logger.debug(
617729
"[query] _run_concurrent_tool_uses: got message type=%s, active=%d",

0 commit comments

Comments
 (0)