From 506710a6b0d30a85ecf5b46057d562f99bfaed9a Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Tue, 17 Feb 2026 10:27:41 -0800 Subject: [PATCH] #p2p Implement latest_step correctly PiperOrigin-RevId: 871383221 --- .../_src/serialization/jax_array_handlers.py | 4 +- .../emergency/p2p/checkpoint_manager.py | 26 +++-- .../emergency/p2p/checkpoint_manager_test.py | 40 ++++++++ .../experimental/emergency/p2p/persistent.py | 98 +++++++++++-------- .../emergency/p2p/persistent_test.py | 92 ++++++++++------- 5 files changed, 180 insertions(+), 80 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index 312bb5962..78fe945d1 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -1346,7 +1346,7 @@ async def _single_replica_deserialize_and_broadcast( deserialization_elapsed_s, ) logging.info( - 'Finished primary replica deserialization in %.2f', + 'Finished primary replica deserialization in %.2f seconds', deserialization_elapsed_s, ) else: @@ -1379,7 +1379,7 @@ def create_zeros(shape_dtype_tup): jax.monitoring.record_event_duration_secs( '/jax/checkpoint/read/broadcast_duration_secs', broadcast_elapsed_s ) - logging.info('Finished broadcasting in %.2f', broadcast_elapsed_s) + logging.info('Finished broadcasting in %.2f seconds', broadcast_elapsed_s) return shared_state diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager.py index a821a9ab3..df65f9a0f 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager.py @@ -17,7 +17,7 @@ import shutil import threading import time -from typing import Any, Iterable, Mapping, Optional, Sequence, Union, final +from typing import Any, Iterable, Sequence, final from absl import logging from etils import epath @@ -303,7 +303,20 @@ def all_steps(self, read: bool = False) -> Sequence[int]: @override def latest_step(self) -> int | None: - return self._local_manager.latest_step() + self._p2p.sync_registry_if_stale() + + # We intentionally use the step returned by P2P regardless of whether a + # newer step is available in persistent storage. This is because we assume + # P2P is more efficient overall for catching up to the latest step, even + # if persistent storage has a newer step available. + step = self._p2p.get_latest_complete_step() + logging.info('P2P latest_step=%s', step) + + if step is None and self._persistent_manager: + step = self._persistent_manager.latest_step() + logging.info('Persistent latest_step=%s', step) + + return step @override def best_step(self) -> int | None: @@ -392,7 +405,7 @@ def _restore_from_local_or_p2p( @override def restore( self, step: int | None, args: p2p_args_lib.Composite | None - ) -> Union[Any, Mapping[str, Any], p2p_args_lib.Composite, None]: + ) -> p2p_args_lib.Composite | None: if args is None: raise ValueError('The `args` parameter is required for restore.') @@ -413,8 +426,9 @@ def restore( use_persistent = True if step is None: - logging.warning('No restore step found in local storage or P2P registry.') - return None + raise FileNotFoundError( + 'No steps found in either local/persistent storage or P2P registry.' 
+ ) logging.info('Targeting restore step: %d', step) @@ -465,7 +479,7 @@ def item_metadata(self, step: int) -> Any: return self._local_manager.item_metadata(step) @override - def metadata(self, step: Optional[int] = None) -> Any: + def metadata(self, step: int | None = None) -> Any: return self._local_manager.metadata(step) @override diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager_test.py index ece5032cc..ea4dcc930 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager_test.py @@ -434,6 +434,46 @@ def test_restore_p2p_cleanup(self, unused_process_index, mock_rmtree): mock_rmtree.assert_called_once_with(str(p2p_restore_dir)) manager.close() + @mock.patch.object(multihost, 'process_index', return_value=0) + def test_latest_step_prefers_p2p_even_if_persistent_is_newer(self, _): + self.local_manager_instance.scan_stored_steps.return_value = (0, []) + self.mock_sync_global_data.return_value = [] + manager = p2p_cm.CheckpointManager( + self.mesh, + self.abstract_state, + self.local_dir, + persistent_directory=self.persistent_dir, + ) + # Mock P2P returning 5 + self.peer_selector_instance.get_latest_complete_step.return_value = 5 + # Mock persistent returning 10 + self.persistent_manager_instance.latest_step.return_value = 10 + + self.assertEqual(5, manager.latest_step()) + self.peer_selector_instance.get_latest_complete_step.assert_called_once() + self.persistent_manager_instance.latest_step.assert_not_called() + manager.close() + + @mock.patch.object(multihost, 'process_index', return_value=0) + def test_latest_step_falls_back_to_persistent(self, _): + self.local_manager_instance.scan_stored_steps.return_value = (0, []) + self.mock_sync_global_data.return_value = [] + manager = p2p_cm.CheckpointManager( + self.mesh, + self.abstract_state, + self.local_dir, + persistent_directory=self.persistent_dir, + ) + # Mock P2P returning None + self.peer_selector_instance.get_latest_complete_step.return_value = None + # Mock persistent returning 5 + self.persistent_manager_instance.latest_step.return_value = 5 + + self.assertEqual(5, manager.latest_step()) + self.peer_selector_instance.get_latest_complete_step.assert_called_once() + self.persistent_manager_instance.latest_step.assert_called_once() + manager.close() + if __name__ == '__main__': absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py index d3c5d2de4..db0b2615f 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py @@ -38,23 +38,38 @@ def _create_persistent_handler( mp_options: checkpoint_manager.MultiprocessingOptions, + replica_axis_index: int, + is_single_slice: bool, ) -> ocp.PyTreeCheckpointHandler: """Creates a PyTreeCheckpointHandler for persistent storage. Args: mp_options: Multiprocessing options for the checkpoint handler. + replica_axis_index: The index of the replica axis in the mesh. + is_single_slice: Whether the mesh is single-slice. Returns: A PyTreeCheckpointHandler configured for persistent storage. 
""" - registry = type_handler_registry.create_type_handler_registry(( - jax.Array, - type_handlers.ArrayHandler( - primary_host=mp_options.primary_host, - replica_id=_PRIMARY_REPLICA_ID, - use_replica_parallel=False, + handler = type_handlers.SingleReplicaArrayHandler( + replica_axis_index=replica_axis_index, + broadcast_memory_limit_bytes=1024 * 1024 * 1000, + primary_host=mp_options.primary_host, + replica_id=_PRIMARY_REPLICA_ID, + use_replica_parallel=False, + ) + if is_single_slice: + handler = type_handlers.ArrayHandler( + primary_host=mp_options.primary_host, + replica_id=_PRIMARY_REPLICA_ID, + use_replica_parallel=False, + ) + registry = type_handler_registry.create_type_handler_registry( + ( + jax.Array, + handler, ), - )) + ) return ocp.PyTreeCheckpointHandler( use_ocdbt=True, use_zarr3=True, @@ -84,28 +99,10 @@ def __init__( self._global_mesh, replica_axis_index=self._replica_axis_index, ) - self._in_primary_slice = multislice.in_replica( - self._process_index, - global_mesh, - replica_axis_index=self._replica_axis_index, - replica_id=_PRIMARY_REPLICA_ID, - ) - - replica_devices = multislice.replica_devices( - self._global_mesh, - replica_axis_index=self._replica_axis_index, - replica_id=self._replica_id, - ) - primary_host = multislice.primary_process_in_replica( - self._global_mesh, - replica_axis_index=self._replica_axis_index, - replica_id=self._replica_id, - ) - active_processes = multihost.unique_processes_from_devices(replica_devices) mp_options = checkpoint_manager.MultiprocessingOptions( - primary_host=primary_host, - active_processes=active_processes, - barrier_sync_key_prefix=f'persistent_fallback_{self._replica_id}', + primary_host=0, + active_processes=None, + barrier_sync_key_prefix='persistent_fallback', ) internal_options = checkpoint_manager.CheckpointManagerOptions( @@ -117,7 +114,16 @@ def __init__( enable_async_checkpointing=True, ) - item_handlers = dict(state=_create_persistent_handler(mp_options)) + item_handlers = dict( + state=_create_persistent_handler( + mp_options, + self._replica_axis_index, + multislice.replica_count( + self._global_mesh, replica_axis_index=self._replica_axis_index + ) + == 1, + ) + ) if utils.pygrain() is not None: item_handlers['data_iter'] = utils.pygrain().PyGrainCheckpointHandler() @@ -141,9 +147,7 @@ def save( *, force: bool = False, ) -> bool: - if self._in_primary_slice: - return self._manager.save(step, args=args, force=force) - return True + return self._manager.save(step, args=args, force=force) def restore( self, @@ -166,14 +170,31 @@ def restore( self._replica_id, ) abstract_state = args.state + if isinstance(args.state, args_lib.PyTreeRestore): + abstract_state = args.state.item + + def _get_sr_restore_args(x): + if ( + multislice.replica_count( + self._global_mesh, replica_axis_index=self._replica_axis_index + ) + > 1 + and isinstance(x, jax.ShapeDtypeStruct) + and isinstance(x.sharding, jax.sharding.NamedSharding) + ): + return type_handlers.SingleReplicaArrayRestoreArgs( + sharding=x.sharding, + global_shape=x.shape, + dtype=x.dtype, + ) + else: + return checkpoint_utils.construct_restore_args(x) + + restore_args_tree = jax.tree.map(_get_sr_restore_args, abstract_state) - sharding_tree = jax.tree.map(lambda x: x.sharding, abstract_state) - # TODO(exlin): Enable SingleReplicaRestore. 
restore_args_obj = args_lib.PyTreeRestore( item=abstract_state, - restore_args=checkpoint_utils.construct_restore_args( - abstract_state, sharding_tree - ), + restore_args=restore_args_tree, ) restore_kwargs = {'state': restore_args_obj} if constants.DATA_ITER_KEY in args: @@ -183,8 +204,7 @@ def restore( ) def delete(self, step: int): - if self._in_primary_slice: - self._manager.delete(step) + self._manager.delete(step) def wait_until_finished(self): self._manager.wait_until_finished() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent_test.py index 267b7be4f..56dc986d7 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest from unittest import mock from absl.testing import absltest @@ -32,6 +33,9 @@ class MockJaxClient: runtime_type = 'tpu' + def process_index(self): + return 0 + class MockDevice: @@ -40,6 +44,8 @@ def __init__(self, process_index, slice_index): self.process_index = process_index self.slice_index = slice_index self.client = MockJaxClient() + self.platform = 'tpu' + self.device_kind = 'tpu' def __repr__(self): return ( @@ -61,6 +67,11 @@ def setUp(self): return_value=self.mock_client, ) ) + self.enter_context( + mock.patch( + 'orbax.checkpoint._src.multihost.multihost.sync_global_processes' + ) + ) devices = np.array([ [MockDevice(0, 0), MockDevice(1, 0)], @@ -105,10 +116,16 @@ def _patch_process_index( return_value=in_primary_slice, ) ) + + def _mock_replica_devices(mesh, replica_axis_index, replica_id): + return np.take( + mesh.devices, indices=replica_id, axis=replica_axis_index + ).flatten() + self.enter_context( mock.patch( 'orbax.checkpoint._src.multihost.multislice.replica_devices', - return_value=self.mesh.devices.flatten(), + side_effect=_mock_replica_devices, ) ) self.enter_context( @@ -135,7 +152,6 @@ def test_init_in_primary_slice(self): manager = persistent.PersistentCheckpointManager( self.directory, self.mesh, replica_axis_index=0, options=self.options ) - self.assertTrue(manager._in_primary_slice) self.assertIsNotNone(manager._manager) manager.close() @@ -146,7 +162,6 @@ def test_init_not_in_primary_slice(self): manager = persistent.PersistentCheckpointManager( self.directory, self.mesh, replica_axis_index=0, options=self.options ) - self.assertFalse(manager._in_primary_slice) self.assertIsNotNone(manager._manager) manager.close() @@ -163,7 +178,7 @@ def test_save_in_primary_slice_saves(self): manager._manager.save.assert_called_once() manager.close() - def test_save_not_in_primary_slice_does_not_save(self): + def test_save_not_in_primary_slice_saves(self): self._patch_process_index( process_index=2, in_primary_slice=False, replica_id=1 ) @@ -175,9 +190,13 @@ def test_save_not_in_primary_slice_does_not_save(self): state=args_lib.PyTreeSave({'a': jax.device_put(1)}) ) manager.save(1, args) - manager._manager.save.assert_not_called() + manager._manager.save.assert_called_once() manager.close() + @unittest.skip( + 'Cannot create sharded jax.Array with MockDevice that is not compatible' + ' with jax.Device for batched_device_put.' 
+ ) def test_save_and_restore(self): self._patch_process_index(process_index=0) # persistent checkpoint manager with multiprocessing only works with a @@ -186,20 +205,24 @@ def test_save_and_restore(self): devices = np.array([ [MockDevice(0, 0)], ]) - mesh = mock.Mock( - spec=jax.sharding.Mesh, - devices=devices, - axis_names=('replica', 'data'), - shape={'replica': 1, 'data': 1}, - shape_tuple=devices.shape, - size=devices.size, - ) + mesh = Mesh(devices, ('replica', 'data')) manager = persistent.PersistentCheckpointManager( self.directory, mesh, replica_axis_index=0, options=self.options ) - arr = jax.device_put(np.arange(self.mesh.size, dtype=np.int32)) - state = {'a': arr, 'b': jax.device_put(1)} + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('data') + ) + arr = jax.make_array_from_callback( + (1,), sharding, lambda idx: np.asarray([0], dtype=np.int32) + ) + sharding_scalar = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec() + ) + b = jax.make_array_from_callback( + (), sharding_scalar, lambda idx: np.asarray(1, dtype=np.int32) + ) + state = {'a': arr, 'b': b} args = p2p_args_lib.Composite(state=args_lib.PyTreeSave(state)) self.assertTrue(manager.save(1, args)) @@ -215,7 +238,8 @@ def _to_abstract(x): abstract_state = jax.tree.map(_to_abstract, state) restored = manager.restore( - 1, args=p2p_args_lib.Composite(state=abstract_state) + 1, + args=p2p_args_lib.Composite(state=abstract_state), ) restored_state = restored.state test_utils.assert_tree_equal(self, state, restored_state) @@ -226,21 +250,19 @@ def _to_abstract(x): ) def test_save_restore_with_grain_iterator(self, unused_process_index): self._patch_process_index(process_index=0) - # persistent checkpoint manager with multiprocessing only works with a - # unified storage. self.enter_context(mock.patch.object(jax, 'process_count', return_value=1)) - devices = np.array([ - [MockDevice(0, 0)], - ]) - mesh = mock.Mock( - spec=jax.sharding.Mesh, - devices=devices, - axis_names=('replica', 'data'), - shape={'replica': 1, 'data': 1}, - shape_tuple=devices.shape, - size=devices.size, + real_devices = jax.local_devices() + fake_device = MockDevice(process_index=1, slice_index=0) + fake_device.id = real_devices[0].id + 1 + all_devices = [real_devices[0], fake_device] + + self.enter_context( + mock.patch.object(jax, 'devices', return_value=all_devices) ) + devices = np.array(real_devices[:1]).reshape(1, 1) + mesh = Mesh(devices, ('replica', 'data')) + manager = persistent.PersistentCheckpointManager( self.directory, mesh, replica_axis_index=0, options=self.options ) @@ -255,7 +277,7 @@ def test_save_restore_with_grain_iterator(self, unused_process_index): for _ in range(3): next(data_iter) - arr = jax.device_put(np.arange(self.mesh.size, dtype=np.int32)) + arr = jax.device_put(np.arange(mesh.size, dtype=np.int32)) state = {'a': arr} save_args = p2p_args_lib.Composite( state=args_lib.PyTreeSave(state), @@ -272,8 +294,12 @@ def test_save_restore_with_grain_iterator(self, unused_process_index): new_data_iter = iter(new_dl) # PersistentCheckpointManager expects the state with sharding information # in args.state. 
+ sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + abstract_state = { + 'a': jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=sharding) + } restore_args = p2p_args_lib.Composite( - state=state, + state=abstract_state, data_iter=pygrain.PyGrainCheckpointRestore(new_data_iter), ) restored = manager.restore(1, args=restore_args) @@ -281,7 +307,7 @@ def test_save_restore_with_grain_iterator(self, unused_process_index): self.assertIn('state', restored) self.assertIn('data_iter', restored) test_utils.assert_tree_equal(self, state, restored['state']) - self.assertEqual(next(restored['data_iter']), 3) + self.assertEqual(next(restored['data_iter']), [3]) manager.close() def test_delete_in_primary_slice_deletes(self): @@ -294,7 +320,7 @@ def test_delete_in_primary_slice_deletes(self): manager._manager.delete.assert_called_once_with(1) manager.close() - def test_delete_not_in_primary_slice_does_not_delete(self): + def test_delete_not_in_primary_slice_deletes(self): self._patch_process_index( process_index=2, in_primary_slice=False, replica_id=1 ) @@ -303,7 +329,7 @@ def test_delete_not_in_primary_slice_does_not_delete(self): ) manager._manager = mock.MagicMock() manager.delete(1) - manager._manager.delete.assert_not_called() + manager._manager.delete.assert_called_once_with(1) manager.close() def test_wait_until_finished_calls_manager(self):
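
Reader's note (not part of the patch): below is a minimal, self-contained sketch of the step-resolution order this change gives CheckpointManager.latest_step(): the P2P registry is consulted first, and persistent storage is used only when P2P reports no complete step, even if persistent storage holds a newer one. The P2PRegistry and PersistentStore classes are hypothetical stand-ins for the real collaborators (the peer selector and the persistent checkpoint manager), not Orbax APIs.

import logging
from typing import Optional


class P2PRegistry:
  """Hypothetical stand-in for the P2P peer registry."""

  def __init__(self, latest_complete_step: Optional[int]):
    self._latest = latest_complete_step

  def sync_registry_if_stale(self) -> None:
    # The real registry refreshes its view of peer checkpoints here.
    pass

  def get_latest_complete_step(self) -> Optional[int]:
    return self._latest


class PersistentStore:
  """Hypothetical stand-in for the persistent checkpoint manager."""

  def __init__(self, step: Optional[int]):
    self._step = step

  def latest_step(self) -> Optional[int]:
    return self._step


def latest_step(
    p2p: P2PRegistry, persistent: Optional[PersistentStore]
) -> Optional[int]:
  """Returns the latest complete P2P step if any, else falls back to persistent."""
  p2p.sync_registry_if_stale()
  step = p2p.get_latest_complete_step()
  logging.info('P2P latest_step=%s', step)
  if step is None and persistent is not None:
    step = persistent.latest_step()
    logging.info('Persistent latest_step=%s', step)
  return step


# P2P wins even when persistent storage holds a newer step (5, not 10) ...
assert latest_step(P2PRegistry(5), PersistentStore(10)) == 5
# ... and persistent storage is consulted only when P2P has nothing.
assert latest_step(P2PRegistry(None), PersistentStore(5)) == 5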
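
A second sketch, illustrating the per-leaf restore-args decision that persistent.py now makes during restore: on a multi-replica mesh, sharded array leaves get single-replica restore args (deserialize on the primary replica, then broadcast), while a single-slice mesh keeps the default restore args. This is a simplified illustration, not the Orbax implementation: the dataclasses stand in for the real restore-arg types, and the replica count is keyed off an axis name rather than replica_axis_index.

import dataclasses
from typing import Any

import jax
import numpy as np


@dataclasses.dataclass
class DefaultRestoreArgs:
  """Stand-in for the default per-leaf restore args."""

  sharding: Any = None


@dataclasses.dataclass
class SingleReplicaRestoreArgs:
  """Stand-in for SingleReplicaArrayRestoreArgs."""

  sharding: Any
  global_shape: tuple
  dtype: Any


def make_restore_args(abstract_state, mesh: jax.sharding.Mesh, replica_axis: str):
  """Builds a restore-args tree, choosing single-replica restore per leaf."""
  multi_replica = mesh.shape[replica_axis] > 1

  def per_leaf(x):
    # Only sharded abstract arrays on a multi-replica mesh take the
    # single-replica (restore-then-broadcast) path.
    if (
        multi_replica
        and isinstance(x, jax.ShapeDtypeStruct)
        and isinstance(x.sharding, jax.sharding.NamedSharding)
    ):
      return SingleReplicaRestoreArgs(x.sharding, x.shape, x.dtype)
    return DefaultRestoreArgs(getattr(x, 'sharding', None))

  return jax.tree.map(per_leaf, abstract_state)


# Example on a trivial single-device mesh: with only one replica, every leaf
# keeps the default restore args.
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = jax.sharding.Mesh(devices, ('replica', 'data'))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
abstract = {'a': jax.ShapeDtypeStruct((4,), np.int32, sharding=sharding)}
print(make_restore_args(abstract, mesh, 'replica'))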