Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ async def _single_replica_deserialize_and_broadcast(
deserialization_elapsed_s,
)
logging.info(
'Finished primary replica deserialization in %.2f',
'Finished primary replica deserialization in %.2f seconds',
deserialization_elapsed_s,
)
else:
Expand Down Expand Up @@ -1379,7 +1379,7 @@ def create_zeros(shape_dtype_tup):
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/read/broadcast_duration_secs', broadcast_elapsed_s
)
logging.info('Finished broadcasting in %.2f', broadcast_elapsed_s)
logging.info('Finished broadcasting in %.2f seconds', broadcast_elapsed_s)

return shared_state

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,16 @@ def all_steps(self, read: bool = False) -> Sequence[int]:

@override
def latest_step(self) -> int | None:
return self._local_manager.latest_step()
self._p2p.sync_registry_if_stale()

step = self._p2p.get_latest_complete_step()
logging.info('P2P latest_step=%s', step)

if step is None and self._persistent_manager:
step = self._persistent_manager.latest_step()
logging.info('Persistent latest_step=%s', step)

return step

@override
def best_step(self) -> int | None:
Expand Down Expand Up @@ -413,8 +422,9 @@ def restore(
use_persistent = True

if step is None:
logging.warning('No restore step found in local storage or P2P registry.')
return None
raise FileNotFoundError(
'No steps found in either local/persistent storage or P2P registry.'
)

logging.info('Targeting restore step: %d', step)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,38 @@

def _create_persistent_handler(
mp_options: checkpoint_manager.MultiprocessingOptions,
replica_axis_index: int,
is_single_slice: bool,
) -> ocp.PyTreeCheckpointHandler:
"""Creates a PyTreeCheckpointHandler for persistent storage.

Args:
mp_options: Multiprocessing options for the checkpoint handler.
replica_axis_index: The index of the replica axis in the mesh.
is_single_slice: Whether the mesh is single-slice.

Returns:
A PyTreeCheckpointHandler configured for persistent storage.
"""
registry = type_handler_registry.create_type_handler_registry((
jax.Array,
type_handlers.ArrayHandler(
primary_host=mp_options.primary_host,
replica_id=_PRIMARY_REPLICA_ID,
use_replica_parallel=False,
handler = type_handlers.SingleReplicaArrayHandler(
replica_axis_index=replica_axis_index,
broadcast_memory_limit_bytes=1024 * 1024 * 1000,
primary_host=mp_options.primary_host,
replica_id=_PRIMARY_REPLICA_ID,
use_replica_parallel=False,
)
if is_single_slice:
handler = type_handlers.ArrayHandler(
primary_host=mp_options.primary_host,
replica_id=_PRIMARY_REPLICA_ID,
use_replica_parallel=False,
)
registry = type_handler_registry.create_type_handler_registry(
(
jax.Array,
handler,
),
))
)
return ocp.PyTreeCheckpointHandler(
use_ocdbt=True,
use_zarr3=True,
Expand Down Expand Up @@ -84,28 +99,10 @@ def __init__(
self._global_mesh,
replica_axis_index=self._replica_axis_index,
)
self._in_primary_slice = multislice.in_replica(
self._process_index,
global_mesh,
replica_axis_index=self._replica_axis_index,
replica_id=_PRIMARY_REPLICA_ID,
)

replica_devices = multislice.replica_devices(
self._global_mesh,
replica_axis_index=self._replica_axis_index,
replica_id=self._replica_id,
)
primary_host = multislice.primary_process_in_replica(
self._global_mesh,
replica_axis_index=self._replica_axis_index,
replica_id=self._replica_id,
)
active_processes = multihost.unique_processes_from_devices(replica_devices)
mp_options = checkpoint_manager.MultiprocessingOptions(
primary_host=primary_host,
active_processes=active_processes,
barrier_sync_key_prefix=f'persistent_fallback_{self._replica_id}',
primary_host=0,
active_processes=None,
barrier_sync_key_prefix='persistent_fallback',
)

internal_options = checkpoint_manager.CheckpointManagerOptions(
Expand All @@ -117,7 +114,16 @@ def __init__(
enable_async_checkpointing=True,
)

item_handlers = dict(state=_create_persistent_handler(mp_options))
item_handlers = dict(
state=_create_persistent_handler(
mp_options,
self._replica_axis_index,
multislice.replica_count(
self._global_mesh, replica_axis_index=self._replica_axis_index
)
== 1,
)
)
if utils.pygrain() is not None:
item_handlers['data_iter'] = utils.pygrain().PyGrainCheckpointHandler()

Expand All @@ -141,9 +147,7 @@ def save(
*,
force: bool = False,
) -> bool:
if self._in_primary_slice:
return self._manager.save(step, args=args, force=force)
return True
return self._manager.save(step, args=args, force=force)

def restore(
self,
Expand All @@ -166,14 +170,31 @@ def restore(
self._replica_id,
)
abstract_state = args.state
if isinstance(args.state, args_lib.PyTreeRestore):
abstract_state = args.state.item

def _get_sr_restore_args(x):
if (
multislice.replica_count(
self._global_mesh, replica_axis_index=self._replica_axis_index
)
> 1
and isinstance(x, jax.ShapeDtypeStruct)
and isinstance(x.sharding, jax.sharding.NamedSharding)
):
return type_handlers.SingleReplicaArrayRestoreArgs(
sharding=x.sharding,
global_shape=x.shape,
dtype=x.dtype,
)
else:
return checkpoint_utils.construct_restore_args(x)

restore_args_tree = jax.tree.map(_get_sr_restore_args, abstract_state)

sharding_tree = jax.tree.map(lambda x: x.sharding, abstract_state)
# TODO(exlin): Enable SingleReplicaRestore.
restore_args_obj = args_lib.PyTreeRestore(
item=abstract_state,
restore_args=checkpoint_utils.construct_restore_args(
abstract_state, sharding_tree
),
restore_args=restore_args_tree,
)
restore_kwargs = {'state': restore_args_obj}
if constants.DATA_ITER_KEY in args:
Expand All @@ -183,8 +204,7 @@ def restore(
)

def delete(self, step: int):
if self._in_primary_slice:
self._manager.delete(step)
self._manager.delete(step)

def wait_until_finished(self):
self._manager.wait_until_finished()
Expand Down
Loading
Loading