fix: restore rollout state after checkpoint update failures #6510

Open
athreesh wants to merge 1 commit into verl-project:main from athreesh:athreesh/checkpoint-engine-update-cleanup

@athreesh
Contributor

Summary

  • finalize checkpoint engines from CheckpointEngineManager cleanup even when process-group build or weight transfer fails
  • always attempt to restore KV cache and resume generation after non-naive checkpoint update interruptions
  • add a CPU unit test covering build failure, transfer failure, cleanup failure, and primary-error preservation
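
The cleanup behavior described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `FakeManager` and its `_transfer_weights`, `_finalize`, and `_resume_generation` methods are hypothetical stand-ins, and the real `CheckpointEngineManager` talks to Ray worker groups instead.

```python
import asyncio


class FakeManager:
    """Hypothetical stand-in for CheckpointEngineManager (not verl's class)."""

    def __init__(self, fail_update=False, fail_finalize=False):
        self.fail_update = fail_update
        self.fail_finalize = fail_finalize
        self.events = []

    async def _transfer_weights(self):
        self.events.append("transfer")
        if self.fail_update:
            raise RuntimeError("transfer failed")

    async def _finalize(self):
        self.events.append("finalize")
        if self.fail_finalize:
            raise RuntimeError("finalize failed")

    async def _resume_generation(self):
        self.events.append("resume")

    async def update_weights(self):
        cleanup_errors = []
        primary_failed = False
        try:
            await self._transfer_weights()
        except BaseException:
            # Remember that the primary operation failed, then re-raise it.
            primary_failed = True
            raise
        finally:
            # Attempt every cleanup step, even if an earlier one fails.
            for step in (self._finalize, self._resume_generation):
                try:
                    await step()
                except Exception as exc:
                    cleanup_errors.append(exc)
            # Surface a cleanup error only when there is no primary error,
            # so the original failure is never masked.
            if cleanup_errors and not primary_failed:
                raise cleanup_errors[0]
```

With both `fail_update=True` and `fail_finalize=True`, the caller sees `transfer failed`, and the recorded events still include `finalize` and `resume`.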

Tests

  • python3 -m py_compile verl/checkpoint_engine/base.py tests/checkpoint_engine/test_manager_cleanup_on_cpu.py
  • PYTHONPATH=/home/ubuntu/verl /home/ubuntu/checkpoint-engine/.venv/bin/python -m pytest tests/checkpoint_engine/test_manager_cleanup_on_cpu.py -q
  • uv run --python /usr/bin/python3.10 --no-project --with ruff ruff check verl/checkpoint_engine/base.py tests/checkpoint_engine/test_manager_cleanup_on_cpu.py
  • git diff --check

Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a robust cleanup mechanism in CheckpointEngineManager.update_weights using a try...finally block to ensure that KV cache and generation states are properly restored even if weight updates or process group builds fail. It also adds comprehensive unit tests to verify these failure paths. The reviewer feedback recommends refactoring blocking ray.get() calls inside async methods to non-blocking asyncio.gather calls to prevent event loop degradation. This async transition requires making finalize_process_group asynchronous, awaiting it in the cleanup phase, and updating the unit test mocks to return awaitable futures instead of raw values.

Comment on lines +395 to +401
def finalize_process_group(self, rollout: RayWorkerGroup):
    """Finalize checkpoint engines on trainer and rollout workers."""
    trainer = self.trainer
    ray.get(
        trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size)
        + rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size)
    )

high

Avoid using ray.get() inside async methods of Ray actors (or helper methods called within them) as it blocks the event loop, which can lead to performance degradation or deadlocks. Instead, make finalize_process_group an asynchronous method and use await asyncio.gather to parallelize the remote calls asynchronously.

Suggested change
-def finalize_process_group(self, rollout: RayWorkerGroup):
-    """Finalize checkpoint engines on trainer and rollout workers."""
-    trainer = self.trainer
-    ray.get(
-        trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size)
-        + rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size)
-    )
+async def finalize_process_group(self, rollout: RayWorkerGroup):
+    """Finalize checkpoint engines on trainer and rollout workers."""
+    trainer = self.trainer
+    await asyncio.gather(
+        *trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size),
+        *rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size)
+    )
References
  1. Avoid using ray.get() inside async methods of Ray actors. This blocks the event loop, which can lead to performance degradation or deadlocks, and typically triggers warnings in Ray. Instead, use await on remote calls or asyncio.gather to parallelize multiple remote calls asynchronously.
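
The event-loop concern can be illustrated without Ray. In the sketch below, `time.sleep` is only an assumed stand-in for a blocking `ray.get(...)` call, and `ticker` and `measure` are hypothetical helpers; the point is that a background task makes no progress while a coroutine blocks, but keeps running while the coroutine awaits `asyncio.gather`.

```python
import asyncio
import contextlib
import time


async def ticker(state):
    # Background task that only advances when the event loop is free.
    while True:
        state["ticks"] += 1
        await asyncio.sleep(0)


async def measure(blocking):
    """Return how many ticks the background task made during the wait."""
    state = {"ticks": 0}
    task = asyncio.create_task(ticker(state))
    await asyncio.sleep(0)  # give the ticker a chance to start
    before = state["ticks"]
    if blocking:
        # Stand-in for a blocking ray.get(...): the loop cannot run anything.
        time.sleep(0.05)
    else:
        # Awaiting yields control, so other tasks keep running.
        await asyncio.gather(asyncio.sleep(0.05), asyncio.sleep(0.05))
    progressed = state["ticks"] - before
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task
    return progressed
```

`asyncio.run(measure(True))` returns 0 because the ticker is starved for the whole wait, while `asyncio.run(measure(False))` returns a positive count.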

Comment on lines +497 to +500
ray.get(
    trainer.update_weights(global_steps=global_steps, mode=self.backend)
    + rollout.update_weights(global_steps=global_steps)
)

high

Avoid using ray.get() inside async methods of Ray actors as it blocks the event loop, which can lead to performance degradation or deadlocks. Instead, use await asyncio.gather to parallelize the remote calls asynchronously.

Suggested change
-ray.get(
-    trainer.update_weights(global_steps=global_steps, mode=self.backend)
-    + rollout.update_weights(global_steps=global_steps)
-)
+await asyncio.gather(
+    *trainer.update_weights(global_steps=global_steps, mode=self.backend),
+    *rollout.update_weights(global_steps=global_steps)
+)
References
  1. Avoid using ray.get() inside async methods of Ray actors. This blocks the event loop, which can lead to performance degradation or deadlocks, and typically triggers warnings in Ray. Instead, use await on remote calls or asyncio.gather to parallelize multiple remote calls asynchronously.

Comment on lines +526 to +530
if finalize_checkpoint_engines and rollout is not None:
    try:
        self.finalize_process_group(rollout)
    except Exception as exc:
        self._record_cleanup_error(cleanup_errors, "finalize checkpoint engines", exc)

high

Since finalize_process_group is now an asynchronous method, it must be awaited to ensure the checkpoint engines are properly finalized.

Suggested change
-if finalize_checkpoint_engines and rollout is not None:
-    try:
-        self.finalize_process_group(rollout)
-    except Exception as exc:
-        self._record_cleanup_error(cleanup_errors, "finalize checkpoint engines", exc)
+if finalize_checkpoint_engines and rollout is not None:
+    try:
+        await self.finalize_process_group(rollout)
+    except Exception as exc:
+        self._record_cleanup_error(cleanup_errors, "finalize checkpoint engines", exc)
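
A small sketch of why the missing `await` matters: calling an `async def` function only creates a coroutine object, so its body never runs and the surrounding `except` never sees a failure. The `finalize` and `cleanup` functions here are hypothetical and only mirror the shape of the snippet above.

```python
import asyncio


async def finalize():
    # Hypothetical finalizer that always fails.
    raise RuntimeError("finalize failed")


async def cleanup(use_await):
    errors = []
    try:
        if use_await:
            await finalize()  # body runs, so the error is raised here
        else:
            coro = finalize()  # only builds a coroutine; nothing runs
            coro.close()  # closed explicitly to avoid a "never awaited" warning
    except Exception as exc:
        errors.append(str(exc))
    return errors
```

Without the `await`, `cleanup(False)` returns an empty list, meaning the failure is silently skipped; `cleanup(True)` returns `["finalize failed"]`.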

Comment on lines +14 to +17
import pytest

from verl.checkpoint_engine import base as checkpoint_base
from verl.checkpoint_engine.base import CheckpointEngineManager

high

Import asyncio to support creating futures for the mocked worker group methods.

Suggested change
-import pytest
-from verl.checkpoint_engine import base as checkpoint_base
-from verl.checkpoint_engine.base import CheckpointEngineManager
+import asyncio
+import pytest
+from verl.checkpoint_engine import base as checkpoint_base
+from verl.checkpoint_engine.base import CheckpointEngineManager

Comment on lines +54 to +65
    def update_weights(self, **kwargs):
        self.events.append(f"{self.name}.update_weights")
        if self.fail_update:
            return [RuntimeError(f"{self.name} update failed")]
        return [f"{self.name}.update_weights.done"]

    def execute_checkpoint_engine(self, methods):
        assert methods == ["finalize"]
        self.events.append(f"{self.name}.finalize")
        if self.fail_finalize:
            return [RuntimeError(f"{self.name} finalize failed")]
        return [f"{self.name}.finalize.done"]

high

Since the weight update and finalization methods now use await asyncio.gather instead of ray.get, the mocked worker group methods must return awaitable objects (such as asyncio.Future) rather than raw strings or exceptions to avoid TypeError.

    def update_weights(self, **kwargs):
        self.events.append(f"{self.name}.update_weights")
        fut = asyncio.Future()
        if self.fail_update:
            fut.set_exception(RuntimeError(f"{self.name} update failed"))
        else:
            fut.set_result(f"{self.name}.update_weights.done")
        return [fut]

    def execute_checkpoint_engine(self, methods):
        assert methods == ["finalize"]
        self.events.append(f"{self.name}.finalize")
        fut = asyncio.Future()
        if self.fail_finalize:
            fut.set_exception(RuntimeError(f"{self.name} finalize failed"))
        else:
            fut.set_result(f"{self.name}.finalize.done")
        return [fut]
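
To see how future-returning mocks behave under `asyncio.gather`, here is a standalone sketch. The `done_future` helper and the `run` coroutine are hypothetical and not part of the test file; the futures are created from the running loop with `create_future()`, which assumes the mocks are only called from inside a coroutine.

```python
import asyncio


def done_future(result=None, error=None):
    # Build an already-completed future on the currently running loop.
    fut = asyncio.get_running_loop().create_future()
    if error is not None:
        fut.set_exception(error)
    else:
        fut.set_result(result)
    return fut


async def run(fail_rollout):
    # Each list mimics what a mocked worker group method would return.
    trainer = [done_future("trainer.done")]
    if fail_rollout:
        rollout = [done_future(error=RuntimeError("rollout failed"))]
    else:
        rollout = [done_future("rollout.done")]
    # gather returns results in argument order, or raises the first error.
    return await asyncio.gather(*trainer, *rollout)
```

`asyncio.run(run(False))` returns `["trainer.done", "rollout.done"]`, and `asyncio.run(run(True))` raises `RuntimeError("rollout failed")`, which is the behavior the failure-path tests would rely on.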

@wuxibin89
Collaborator

A checkpoint update failure is unrecoverable; the job should be restarted.
