fix: restore rollout state after checkpoint update failures #6510
athreesh wants to merge 1 commit
Code Review
This pull request introduces a robust cleanup mechanism in CheckpointEngineManager.update_weights using a try...finally block to ensure that KV cache and generation states are properly restored even if weight updates or process group builds fail. It also adds comprehensive unit tests to verify these failure paths. The reviewer feedback recommends refactoring blocking ray.get() calls inside async methods to non-blocking asyncio.gather calls to prevent event loop degradation. This async transition requires making finalize_process_group asynchronous, awaiting it in the cleanup phase, and updating the unit test mocks to return awaitable futures instead of raw values.
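The cleanup pattern described in this summary can be sketched roughly as follows. This is a minimal illustration, not the code from the pull request: `FakeRollout`, `release_kv_cache`, `resume_kv_cache`, and `failing_sync` are hypothetical names, and the real `CheckpointEngineManager.update_weights` may order its steps differently.

```python
import asyncio


class FakeRollout:
    """Hypothetical stand-in for the rollout side; records what was called."""

    def __init__(self):
        self.events = []

    async def release_kv_cache(self):
        self.events.append("release_kv_cache")

    async def resume_kv_cache(self):
        self.events.append("resume_kv_cache")


async def update_weights(rollout, sync_weights):
    """Run sync_weights, restoring rollout state even if it raises."""
    cleanup_errors = []
    await rollout.release_kv_cache()
    try:
        await sync_weights()
    finally:
        # Cleanup runs on success and on failure. Its own errors are
        # recorded so they do not mask the original exception.
        try:
            await rollout.resume_kv_cache()
        except Exception as exc:
            cleanup_errors.append(("resume kv cache", exc))
    return cleanup_errors


async def failing_sync():
    # Hypothetical weight sync that always fails.
    raise RuntimeError("weight update failed")


async def demo():
    rollout = FakeRollout()
    try:
        await update_weights(rollout, failing_sync)
    except RuntimeError as exc:
        return rollout.events, str(exc)
    return rollout.events, None


events, error = asyncio.run(demo())
```

In this sketch the original `RuntimeError` still reaches the caller, but the rollout has already been told to resume by the time it does.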
    def finalize_process_group(self, rollout: RayWorkerGroup):
        """Finalize checkpoint engines on trainer and rollout workers."""
        trainer = self.trainer
        ray.get(
            trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size)
            + rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size)
        )
Avoid using ray.get() inside async methods of Ray actors (or helper methods called within them) as it blocks the event loop, which can lead to performance degradation or deadlocks. Instead, make finalize_process_group an asynchronous method and use await asyncio.gather to parallelize the remote calls asynchronously.
Suggested change
    - def finalize_process_group(self, rollout: RayWorkerGroup):
    -     """Finalize checkpoint engines on trainer and rollout workers."""
    -     trainer = self.trainer
    -     ray.get(
    -         trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size)
    -         + rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size)
    -     )
    + async def finalize_process_group(self, rollout: RayWorkerGroup):
    +     """Finalize checkpoint engines on trainer and rollout workers."""
    +     trainer = self.trainer
    +     await asyncio.gather(
    +         *trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size),
    +         *rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size)
    +     )
References
- Avoid using ray.get() inside async methods of Ray actors. This blocks the event loop, which can lead to performance degradation or deadlocks, and typically triggers warnings in Ray. Instead, use await on remote calls or asyncio.gather to parallelize multiple remote calls asynchronously.
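The effect of a blocking wait on an event loop can be seen without Ray at all. In the sketch below, `time.sleep` is only a stand-in for a synchronous call such as `ray.get()`, and `asyncio.sleep` is a stand-in for awaiting remote results; `heartbeat`, `blocking_wait`, `cooperative_wait`, and `ticks_during` are hypothetical helpers written for this illustration.

```python
import asyncio
import time


async def heartbeat(ticks):
    """Append a tick every 10 ms, but only when the event loop can run us."""
    for i in range(5):
        ticks.append(i)
        await asyncio.sleep(0.01)


async def blocking_wait():
    # Stand-in for ray.get(): a synchronous wait that never yields to the loop.
    time.sleep(0.2)


async def cooperative_wait():
    # Stand-in for `await asyncio.gather(...)`: yields to the loop while waiting.
    await asyncio.gather(asyncio.sleep(0.2), asyncio.sleep(0.2))


async def ticks_during(wait):
    """Count how many heartbeat ticks were recorded by the time `wait` returns."""
    ticks = []
    task = asyncio.create_task(heartbeat(ticks))
    await asyncio.sleep(0)  # let the heartbeat record its first tick
    await wait()
    seen = len(ticks)
    await task  # let the heartbeat finish before the loop closes
    return seen


blocked_ticks = asyncio.run(ticks_during(blocking_wait))
cooperative_ticks = asyncio.run(ticks_during(cooperative_wait))
```

While the blocking wait is in progress, the heartbeat task makes no progress beyond its first tick; with the cooperative wait it keeps running. Other coroutines on the same actor would be starved in the same way.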
    ray.get(
        trainer.update_weights(global_steps=global_steps, mode=self.backend)
        + rollout.update_weights(global_steps=global_steps)
    )
Avoid using ray.get() inside async methods of Ray actors as it blocks the event loop, which can lead to performance degradation or deadlocks. Instead, use await asyncio.gather to parallelize the remote calls asynchronously.
Suggested change
    - ray.get(
    -     trainer.update_weights(global_steps=global_steps, mode=self.backend)
    -     + rollout.update_weights(global_steps=global_steps)
    - )
    + await asyncio.gather(
    +     *trainer.update_weights(global_steps=global_steps, mode=self.backend),
    +     *rollout.update_weights(global_steps=global_steps)
    + )
    if finalize_checkpoint_engines and rollout is not None:
        try:
            self.finalize_process_group(rollout)
        except Exception as exc:
            self._record_cleanup_error(cleanup_errors, "finalize checkpoint engines", exc)
Since finalize_process_group is now an asynchronous method, it must be awaited to ensure the checkpoint engines are properly finalized.
Suggested change
    - if finalize_checkpoint_engines and rollout is not None:
    -     try:
    -         self.finalize_process_group(rollout)
    -     except Exception as exc:
    -         self._record_cleanup_error(cleanup_errors, "finalize checkpoint engines", exc)
    + if finalize_checkpoint_engines and rollout is not None:
    +     try:
    +         await self.finalize_process_group(rollout)
    +     except Exception as exc:
    +         self._record_cleanup_error(cleanup_errors, "finalize checkpoint engines", exc)
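Why the missing `await` matters: calling an `async def` function only creates a coroutine object, and its body does not run until it is awaited. The sketch below uses a hypothetical, simplified `finalize_process_group` to show this; it is not the method from the pull request.

```python
import asyncio

events = []


async def finalize_process_group():
    """Hypothetical stand-in for the async finalize step."""
    events.append("finalized")


async def cleanup_without_await():
    coro = finalize_process_group()  # creates a coroutine object; the body has not run
    ran = bool(events)
    coro.close()  # close it explicitly to avoid a "never awaited" RuntimeWarning
    return ran


async def cleanup_with_await():
    await finalize_process_group()  # the body actually runs here
    return bool(events)


ran_without_await = asyncio.run(cleanup_without_await())
ran_with_await = asyncio.run(cleanup_with_await())
```

A `try`/`except` wrapped around the un-awaited call would also never see an exception raised inside the coroutine, so cleanup errors would go unrecorded.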
    import pytest

    from verl.checkpoint_engine import base as checkpoint_base
    from verl.checkpoint_engine.base import CheckpointEngineManager
Import asyncio to support creating futures for the mocked worker group methods.
Suggested change
    - import pytest
    - from verl.checkpoint_engine import base as checkpoint_base
    - from verl.checkpoint_engine.base import CheckpointEngineManager
    + import asyncio
    + import pytest
    + from verl.checkpoint_engine import base as checkpoint_base
    + from verl.checkpoint_engine.base import CheckpointEngineManager
    def update_weights(self, **kwargs):
        self.events.append(f"{self.name}.update_weights")
        if self.fail_update:
            return [RuntimeError(f"{self.name} update failed")]
        return [f"{self.name}.update_weights.done"]

    def execute_checkpoint_engine(self, methods):
        assert methods == ["finalize"]
        self.events.append(f"{self.name}.finalize")
        if self.fail_finalize:
            return [RuntimeError(f"{self.name} finalize failed")]
        return [f"{self.name}.finalize.done"]
Since the weight update and finalization methods now use await asyncio.gather instead of ray.get, the mocked worker group methods must return awaitable objects (such as asyncio.Future) rather than raw strings or exceptions to avoid TypeError.
Suggested change
    def update_weights(self, **kwargs):
        self.events.append(f"{self.name}.update_weights")
        fut = asyncio.Future()
        if self.fail_update:
            fut.set_exception(RuntimeError(f"{self.name} update failed"))
        else:
            fut.set_result(f"{self.name}.update_weights.done")
        return [fut]

    def execute_checkpoint_engine(self, methods):
        assert methods == ["finalize"]
        self.events.append(f"{self.name}.finalize")
        fut = asyncio.Future()
        if self.fail_finalize:
            fut.set_exception(RuntimeError(f"{self.name} finalize failed"))
        else:
            fut.set_result(f"{self.name}.finalize.done")
        return [fut]
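A small standalone sketch of how pre-resolved futures behave under `asyncio.gather` may help when reading the mocks above. `fake_remote_call`, `run_both`, and `gather_raw_value` are hypothetical helpers for illustration only. The sketch creates futures with `loop.create_future()` from the running loop, which is one way to avoid depending on an implicit current event loop; the suggestion above uses `asyncio.Future()` directly.

```python
import asyncio


def fake_remote_call(name, fail=False):
    """Hypothetical mock: returns a list holding one already-resolved future."""
    fut = asyncio.get_running_loop().create_future()
    if fail:
        fut.set_exception(RuntimeError(f"{name} failed"))
    else:
        fut.set_result(f"{name}.done")
    return [fut]


async def run_both(fail_rollout):
    """Gather trainer and rollout results; report the first failure as text."""
    try:
        return await asyncio.gather(
            *fake_remote_call("trainer"),
            *fake_remote_call("rollout", fail=fail_rollout),
        )
    except RuntimeError as exc:
        return str(exc)


async def gather_raw_value():
    """Show that a plain string is not awaitable, so gather rejects it."""
    try:
        await asyncio.gather("trainer.done")
    except TypeError:
        return True
    return False


ok_result = asyncio.run(run_both(fail_rollout=False))
failed_result = asyncio.run(run_both(fail_rollout=True))
raw_value_rejected = asyncio.run(gather_raw_value())
```

A future that carries an exception re-raises it at the `await`, which is what lets a test exercise the failure path through the same `asyncio.gather` call as the success path.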
A checkpoint update failure is unrecoverable, and the job should be restarted.
Summary
Tests