4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- `uvloop` dependency for improved event loop performance
- #v1 Add `use_load_and_broadcast` option.

### Removed

@@ -22,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- #v1 Make most V1 public concrete classes final.
- Allow restore+broadcast logic to work without a single-replica sharding
parameter, which is always constructed as a sharding over replica-local
devices anyway.
- Refactor `CheckpointLayout` splitting `load()` into `load_pytree()` and
`load_checkpointables()` each with their own dedicated loading logic
- Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout`
6 changes: 4 additions & 2 deletions checkpoint/orbax/checkpoint/_src/multihost/multislice.py
@@ -38,7 +38,8 @@ def process_replica_id(
*,
replica_axis_index: int = 0,
) -> int:
"""Returns the slice id that the process_index belongs to."""
"""Returns the replica id that the process_index belongs to."""

for replica_id in range(
replica_count(global_mesh, replica_axis_index=replica_axis_index)
):
@@ -64,6 +65,7 @@ def replica_devices(
replica_id: int = 0,
replica_axis_index: int = 0,
) -> np.ndarray:
"""Returns devices for the replica with the given ID."""
return np.take(
global_mesh.devices,
replica_id,
@@ -83,7 +85,7 @@
def local_replica_devices(
global_mesh: jax.sharding.Mesh, *, replica_axis_index: int = 0
) -> np.ndarray:
"""Get devices in the host-local slice."""
"""Get devices for the replica that the current process is in."""
for replica_id in range(
replica_count(global_mesh, replica_axis_index=replica_axis_index)
):
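The replica selection above is a plain axis slice of the device grid. A minimal numpy-only sketch (using a hypothetical 2x4 array of device ids in place of `global_mesh.devices`) shows what `replica_devices` picks out for a given `replica_id`:

import numpy as np

# Hypothetical stand-in for global_mesh.devices: 2 replicas x 4 devices each.
devices = np.arange(8).reshape(2, 4)

# The same selection replica_devices() performs: one index along the replica axis.
replica = np.take(devices, 1, axis=0)  # replica_id=1, replica_axis_index=0
print(replica)  # [4 5 6 7] -> the devices that make up replica 1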
@@ -1452,7 +1452,7 @@ def __init__(
replica_axis_index: Defines the axis of the global mesh along which
replicas are defined. E.g. all devices in
global_mesh.devices[replica_axis_index] are part of the same replica.
primary_replica_id: The id of the replica hosts that is used to load and
primary_replica_id: The id of the replica that is used to load and
broadcast the checkpoint.
broadcast_memory_limit_bytes: Specifies the memory size (in bytes) used
for broadcasting data.
@@ -1469,6 +1469,23 @@ def __init__(
self.broadcast_memory_limit_bytes = broadcast_memory_limit_bytes
self.broadcast_memory_scaling_factor = broadcast_memory_scaling_factor

def _construct_single_replica_sharding(
self, sharding: jax.sharding.Sharding
) -> jax.sharding.Sharding:
"""Constructs a single replica sharding."""
assert isinstance(sharding, jax.sharding.NamedSharding)
local_replica_devices = multislice.local_replica_devices(
sharding.mesh, replica_axis_index=self.replica_axis_index
)
local_replica_devices = np.expand_dims(
local_replica_devices, axis=self.replica_axis_index
)
replica_mesh = jax.sharding.Mesh(
local_replica_devices,
sharding.mesh.axis_names,
)
return jax.sharding.NamedSharding(replica_mesh, sharding.spec)

async def deserialize(
self,
infos: Sequence[types.ParamInfo],
@@ -1500,11 +1517,18 @@ async def deserialize(
f' {type(arg)}.'
)
if arg.sharding is None:
raise ValueError('Must provide `sharding`.')
if arg.single_replica_sharding is None:
raise ValueError('Must provide `single_replica_sharding`.')
raise ValueError(
'Must provide `sharding` to restore with'
' `SingleReplicaArrayHandler`.'
)

single_replica_shardings = [arg.single_replica_sharding for arg in args]
# arg.single_replica_sharding is optional; construct one from arg.sharding
# when it is not provided.
single_replica_shardings = [
arg.single_replica_sharding
if arg.single_replica_sharding is not None
else self._construct_single_replica_sharding(arg.sharding)
for arg in args
]
shardings = [arg.sharding for arg in args]

if self._dispatcher is None:
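The new `_construct_single_replica_sharding` helper simply re-wraps the original `NamedSharding` over the local replica's devices. The standalone sketch below mirrors that construction under assumed inputs: the mesh axis names and partition spec are hypothetical, and `devices[0]` stands in for `multislice.local_replica_devices(...)`.

import jax
import numpy as np

replica_axis_index = 0

# Hypothetical global mesh: axis 0 indexes replicas, axis 1 shards the model.
devices = np.array(jax.devices()).reshape(-1, 1)
global_mesh = jax.sharding.Mesh(devices, ('replica', 'model'))
sharding = jax.sharding.NamedSharding(
    global_mesh, jax.sharding.PartitionSpec(None, 'model')
)

# Mirror of _construct_single_replica_sharding: keep only this process's
# replica devices, re-insert the (now size-1) replica axis, and reuse the
# original axis names and partition spec.
local_devices = np.expand_dims(devices[0], axis=replica_axis_index)
replica_mesh = jax.sharding.Mesh(local_devices, global_mesh.axis_names)
single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, sharding.spec)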
@@ -92,9 +92,9 @@ class SingleReplicaArrayRestoreArgs(ArrayRestoreArgs):
on one replica hosts and do broadcasting which should significantly
improve the training start time at scale.

single_replica_sharding:
jax.sharding.NamedSharding object which describes the single replica
sharding to which current host belongs to.
single_replica_sharding: [Deprecated] Provided for backward compatibility
only. It is not needed, as Orbax will automatically construct a
single-replica sharding to use for restoring before broadcasting.
"""

single_replica_sharding: jax.sharding.NamedSharding | None = None
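With the field now optional, a restore arg can be built from the target `sharding` alone. A minimal sketch, assuming the class is used via the public `orbax.checkpoint.type_handlers` export and a hypothetical 1-D 'replica' mesh:

import jax
import numpy as np
from orbax.checkpoint import type_handlers

# Hypothetical fully-replicated sharding over a 1-D 'replica' mesh.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('replica',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

# single_replica_sharding is left unset; per this change it is derived
# automatically from `sharding` and the handler's replica_axis_index.
restore_args = type_handlers.SingleReplicaArrayRestoreArgs(
    restore_type=jax.Array,
    sharding=sharding,
)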
@@ -90,6 +90,7 @@ def get_pathways_array_handler(
**kwargs,
) -> type_handlers.ArrayHandler:
"""Returns the Pathways ArrayHandler with the given options."""

# If not set, use whichever dispatcher implementation is available.
checkpointing_impl = checkpointing_impl or CheckpointingImpl.from_options(
use_colocated_python=True,
11 changes: 10 additions & 1 deletion checkpoint/orbax/checkpoint/_src/serialization/pathways_types.py
@@ -24,7 +24,16 @@


class CheckpointingImpl(enum.Enum):
"""The implementation to use for Pathways checkpointing."""
"""The implementation to use for Pathways checkpointing.

These implementations include:
- Colocated Python
- Remote Python
- Persistence Array Handler
- No Dispatcher

"No dispatcher" means that Pathways will not be used.
"""

NO_DISPATCHER = enum.auto()
COLOCATED_PYTHON = enum.auto()
@@ -282,11 +282,47 @@ class Loading:
If True, restoration allows silent truncating/padding of arrays if the
stored array shape does not match the target shape. Otherwise, raises an
error.
raise_array_data_missing_error:
If True, raises an error if array data is missing. Otherwise allows
returning zeros from an array range that was not necessarily written to.
use_load_and_broadcast: Whether to use load-and-broadcast for multi-replica
loading. This is useful when the model has multiple replicas across
different sets of devices (commonly across multiple TPU slices, but also
applies to data-parallel model replicas within a single slice). Array
shardings must be structured so that the mesh has a dimension along which
all model weights are replicated. The checkpoint is then loaded only on
the hosts and devices of replica `primary_replica_id` along the
`replica_axis_index` dimension, and broadcast to all other replicas.
"""

@dataclasses.dataclass(frozen=True, kw_only=True)
class LoadAndBroadcastOptions:
"""Used to configure load-and-broadcast behavior in multi-replica loading.

replica_axis_index: Defines the axis of the global mesh along which
replicas are defined. E.g. all devices in
global_mesh.devices[replica_axis_index] are part of the same replica.
primary_replica_id: The id of the replica that is used to load and
broadcast the checkpoint.
broadcast_memory_limit_bytes: Specifies the memory size (in bytes) used
for broadcasting data.
broadcast_memory_scaling_factor: Specifies the fraction of available
memory to use for broadcasting data.
"""

replica_axis_index: int | None = 0
primary_replica_id: int | None = 0
broadcast_memory_limit_bytes: int | None = None
broadcast_memory_scaling_factor: float | None = 0.75

concurrent_bytes: int | None = None
enable_padding_and_truncation: bool = False
raise_array_data_missing_error: bool = True
use_load_and_broadcast: bool = False
load_and_broadcast_options: LoadAndBroadcastOptions = dataclasses.field(
default_factory=LoadAndBroadcastOptions
)

saving: Saving = dataclasses.field(default_factory=Saving)
loading: Loading = dataclasses.field(default_factory=Loading)
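Taken together, the new fields might be configured as in the sketch below. The outer `ArrayOptions` name and its module path are assumptions inferred from `context.array_options.loading` elsewhere in this change, not confirmed by this diff.

# Hypothetical import path, inferred from how the v1 context reads these options.
from orbax.checkpoint.experimental.v1._src.context import options as options_lib

loading = options_lib.ArrayOptions.Loading(
    use_load_and_broadcast=True,
    load_and_broadcast_options=options_lib.ArrayOptions.Loading.LoadAndBroadcastOptions(
        replica_axis_index=0,  # replicas are laid out along mesh axis 0
        primary_replica_id=0,  # replica 0 reads the checkpoint from storage
        broadcast_memory_scaling_factor=0.75,
    ),
)
array_options = options_lib.ArrayOptions(loading=loading)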
@@ -33,6 +33,7 @@
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils
from orbax.checkpoint.experimental.v1._src.serialization import registration
from orbax.checkpoint.experimental.v1._src.serialization import types


@@ -80,22 +81,7 @@ def _create_v0_array_handler(
context: context_lib.Context,
) -> type_handlers_v0.ArrayHandler:
"""Creates a V0 array handler from a V1 context."""

saving_options = context.array_options.saving
primary_host = context.multiprocessing_options.primary_host
array_handler = type_handlers_v0.ArrayHandler(
primary_host=primary_host,
replica_id=None if primary_host is None else 0,
use_replica_parallel=saving_options.use_replica_parallel,
min_slice_bytes_for_replica_parallel=saving_options.min_slice_bytes_for_replica_parallel,
max_replicas_for_replica_parallel=saving_options.max_replicas_for_replica_parallel,
enable_replica_parallel_separate_folder=saving_options.enable_replica_parallel_separate_folder,
enable_write_sharding_file=saving_options.enable_write_sharding_file,
array_metadata_store=saving_options.array_metadata_store,
)


return array_handler
return registration.get_array_handler(context)


def _create_v0_saving_paraminfo(
@@ -184,12 +170,17 @@ def _create_v0_restorearg(
context: context_lib.Context,
) -> type_handlers_v0.ArrayRestoreArgs:
"""Creates a V0 `ArrayRestoreArgs` from V1 params."""
restore_arg_cls = (
type_handlers_v0.SingleReplicaArrayRestoreArgs
if context.array_options.loading.use_load_and_broadcast
else type_handlers_v0.ArrayRestoreArgs
)
value = param.value
if value is None or isinstance(value, type):
return type_handlers_v0.ArrayRestoreArgs(restore_type=jax.Array)
return restore_arg_cls(restore_type=jax.Array)
elif protocol_utils.is_subclass_protocol(value, AbstractShardedArray):
value = typing.cast(AbstractShardedArray, value)
return type_handlers_v0.ArrayRestoreArgs(
return restore_arg_cls(
restore_type=jax.Array,
dtype=value.dtype,
sharding=value.sharding,
@@ -68,8 +68,7 @@ class NumpyMetadata(AbstractArray):

def _create_v0_numpy_handler() -> type_handlers_v0.NumpyHandler:
"""Creates a V0 `NumpyHandler`."""
numpy_handler = type_handlers_v0.NumpyHandler()
return numpy_handler
return registration.get_numpy_handler()


def _create_v0_saving_paraminfo(
@@ -0,0 +1,126 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines helpers for creating array `v0.TypeHandler`.

Functions should return a `TypeHandler` that can be wrapped into a
:py:class:`LeafHandler`. It should return the appropriate handler based on
global settings and runtime (Pathways vs. mcJAX).

This structure also helps prevent users from including Pathways dependencies in
their
binaries when they are not running on Pathways. The Pathways imports are
deferred until `is_pathways_backend()` returns True.

Pathways dependencies should not be added to this file.
"""

from orbax.checkpoint._src.serialization import jax_array_handlers
from orbax.checkpoint._src.serialization import pathways_handler_registry
from orbax.checkpoint._src.serialization import pathways_types
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.synchronization import multihost

_PATHWAYS_IMPORT_ERROR_MSG = """
Failed to import Pathways dependencies. Please ensure you have
linked the Pathways dependencies required by Orbax to your binary.
These are found at
//orbax/checkpoint/experimental/v1:pathways_support.
Please note that these dependencies are not linked automatically
because Pathways pulls in many transitive dependencies that
non-Pathways users wish to avoid linking.
"""


def resolve_pathways_checkpointing_impl(
context: context_lib.Context,
) -> pathways_types.CheckpointingImpl:
"""Returns the Pathways checkpointing implementation."""
try:
# pylint: disable=g-import-not-at-top
# pytype: disable=import-error
from .learning.deepmind.jax.ocean.remote_python import rp
# pytype: enable=import-error
# pylint: enable=g-import-not-at-top
except ImportError as e:
raise ImportError(_PATHWAYS_IMPORT_ERROR_MSG) from e
checkpointing_impl = context.pathways_options.checkpointing_impl
return checkpointing_impl or pathways_types.CheckpointingImpl.from_options(
use_colocated_python=False, # Not enabled unless explicitly requested.
use_remote_python=rp.available(),
use_persistence_array_handler=True, # Only used as a fallback.
)


def get_array_handler(
context: context_lib.Context,
) -> type_handlers.ArrayHandler:
"""Returns the TypeHandler for JAX arrays (pytree leaves)."""
saving_options = context.array_options.saving
loading_options = context.array_options.loading
primary_host = context.multiprocessing_options.primary_host
common_kwargs = dict(
primary_host=primary_host,
replica_id=None if primary_host is None else 0,
use_replica_parallel=saving_options.use_replica_parallel,
min_slice_bytes_for_replica_parallel=saving_options.min_slice_bytes_for_replica_parallel,
max_replicas_for_replica_parallel=saving_options.max_replicas_for_replica_parallel,
enable_replica_parallel_separate_folder=saving_options.enable_replica_parallel_separate_folder,
enable_write_sharding_file=saving_options.enable_write_sharding_file,
array_metadata_store=saving_options.array_metadata_store,
)
if loading_options.use_load_and_broadcast:
load_and_broadcast_kwargs = dict(
replica_axis_index=loading_options.load_and_broadcast_options.replica_axis_index,
primary_replica_id=loading_options.load_and_broadcast_options.primary_replica_id,
broadcast_memory_limit_bytes=loading_options.load_and_broadcast_options.broadcast_memory_limit_bytes,
broadcast_memory_scaling_factor=loading_options.load_and_broadcast_options.broadcast_memory_scaling_factor,
)
else:
load_and_broadcast_kwargs = dict()

if multihost.is_pathways_backend():
checkpointing_impl = resolve_pathways_checkpointing_impl(context)
return pathways_handler_registry.get_pathways_array_handler(
use_single_replica_array_handler=loading_options.use_load_and_broadcast,
checkpointing_impl=checkpointing_impl,
**common_kwargs,
**load_and_broadcast_kwargs,
)
else:
if loading_options.use_load_and_broadcast:
return jax_array_handlers.SingleReplicaArrayHandler(
dispatcher=None,
**common_kwargs,
**load_and_broadcast_kwargs,
)
else:
return jax_array_handlers.ArrayHandler(dispatcher=None, **common_kwargs)


def get_numpy_handler() -> type_handlers.NumpyHandler:
"""Returns the TypeHandler for Numpy arrays."""
if multihost.is_pathways_backend():
return pathways_handler_registry.get_pathways_numpy_handler()
else:
return type_handlers.NumpyHandler()


def get_scalar_handler() -> type_handlers.ScalarHandler:
"""Returns the TypeHandler for scalars."""
if multihost.is_pathways_backend():
return pathways_handler_registry.get_pathways_scalar_handler()
else:
return type_handlers.ScalarHandler()
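As a rough usage sketch of the new registry, the snippet below shows how the handlers might be obtained. How a `Context` is acquired here is an assumption; in the leaf handlers above it is simply passed in by the caller.

# Schematic usage only; get_context() is an assumed accessor on the v1
# context module and may differ in practice.
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.serialization import registration

context = context_lib.get_context()
array_handler = registration.get_array_handler(context)  # Pathways- and options-aware
numpy_handler = registration.get_numpy_handler()
scalar_handler = registration.get_scalar_handler()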
@@ -39,8 +39,7 @@

def _create_v0_scalar_handler() -> type_handlers_v0.ScalarHandler:
"""Creates a V0 ScalarHandler."""
scalar_handler = type_handlers_v0.ScalarHandler()
return scalar_handler
return registration.get_scalar_handler()


def _create_v0_saving_paraminfo(