From 631bc4a21f77557c8018ba292cf556389a3daf4b Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Fri, 13 Feb 2026 13:38:53 -0800 Subject: [PATCH] #v1 Add `use_load_and_broadcast` option. PiperOrigin-RevId: 869869967 --- checkpoint/CHANGELOG.md | 4 + .../checkpoint/_src/multihost/multislice.py | 6 +- .../_src/serialization/jax_array_handlers.py | 34 ++++- .../serialization/jax_array_restore_args.py | 6 +- .../pathways_handler_registry.py | 1 + .../_src/serialization/pathways_types.py | 11 +- .../experimental/v1/_src/context/options.py | 36 +++++ .../_src/serialization/array_leaf_handler.py | 27 ++-- .../_src/serialization/numpy_leaf_handler.py | 3 +- .../v1/_src/serialization/registration.py | 126 ++++++++++++++++++ .../_src/serialization/scalar_leaf_handler.py | 3 +- .../v1/_src/testing/save_load_test_base.py | 44 +++++- checkpoint/orbax/checkpoint/test_utils.py | 32 ----- 13 files changed, 264 insertions(+), 69 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 5ce0e3c55..a91279357 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `uvloop` dependency for improved event loop performance +- #v1 Add `use_load_and_broadcast` option. ### Removed @@ -22,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - #v1 Make most V1 public concrete classes final. +- Allow restore+broadcast logic to not require a single-replica sharding +parameter, which is always constructed as a sharding over replica-local +devices anyway. - Refactor `CheckpointLayout` splitting `load()` into `load_pytree()` and `load_checkpointables()` each with their own dedicated loading logic - Refactor v0 Pytree validation and metadata resolution and add `OrbaxV0Layout` diff --git a/checkpoint/orbax/checkpoint/_src/multihost/multislice.py b/checkpoint/orbax/checkpoint/_src/multihost/multislice.py index 7de71073d..9fce17b8c 100644 --- a/checkpoint/orbax/checkpoint/_src/multihost/multislice.py +++ b/checkpoint/orbax/checkpoint/_src/multihost/multislice.py @@ -38,7 +38,8 @@ def process_replica_id( *, replica_axis_index: int = 0, ) -> int: - """Returns the slice id that the process_index belongs to.""" + """Returns the replica id that the process_index belongs to.""" + for replica_id in range( replica_count(global_mesh, replica_axis_index=replica_axis_index) ): @@ -64,6 +65,7 @@ def replica_devices( replica_id: int = 0, replica_axis_index: int = 0, ) -> np.ndarray: + """Returns devices for the replica with the given ID.""" return np.take( global_mesh.devices, replica_id, @@ -83,7 +85,7 @@ def replica_count( def local_replica_devices( global_mesh: jax.sharding.Mesh, *, replica_axis_index: int = 0 ) -> np.ndarray: - """Get devices in the host-local slice.""" + """Get devices for the replica that the current process is in.""" for replica_id in range( replica_count(global_mesh, replica_axis_index=replica_axis_index) ): diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index 05900fa57..d5566e4cc 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -1452,7 +1452,7 @@ def __init__( replica_axis_index: Defines the axis of the global mesh along which replicas are defined. E.g. all devices in global_mesh.devices[replica_axis_index] are part of the same replica. - primary_replica_id: The id of the replica hosts that is used to load and + primary_replica_id: The id of the replica that is used to load and broadcast the checkpoint. broadcast_memory_limit_bytes: Specifies the memory size (in bytes) used for broadcasting data. @@ -1469,6 +1469,23 @@ def __init__( self.broadcast_memory_limit_bytes = broadcast_memory_limit_bytes self.broadcast_memory_scaling_factor = broadcast_memory_scaling_factor + def _construct_single_replica_sharding( + self, sharding: jax.sharding.Sharding + ) -> jax.sharding.Sharding: + """Constructs a single replica sharding.""" + assert isinstance(sharding, jax.sharding.NamedSharding) + local_replica_devices = multislice.local_replica_devices( + sharding.mesh, replica_axis_index=self.replica_axis_index + ) + local_replica_devices = np.expand_dims( + local_replica_devices, axis=self.replica_axis_index + ) + replica_mesh = jax.sharding.Mesh( + local_replica_devices, + sharding.mesh.axis_names, + ) + return jax.sharding.NamedSharding(replica_mesh, sharding.spec) + async def deserialize( self, infos: Sequence[types.ParamInfo], @@ -1500,11 +1517,18 @@ async def deserialize( f' {type(arg)}.' ) if arg.sharding is None: - raise ValueError('Must provide `sharding`.') - if arg.single_replica_sharding is None: - raise ValueError('Must provide `single_replica_sharding`.') + raise ValueError( + 'Must provide `sharding` to restore with' + ' `SingleReplicaArrayHandler`.' + ) - single_replica_shardings = [arg.single_replica_sharding for arg in args] + # arg.single_replica_sharding is not required to be passed. + single_replica_shardings = [ + arg.single_replica_sharding + if arg.single_replica_sharding + else self._construct_single_replica_sharding(arg.sharding) + for arg in args + ] shardings = [arg.sharding for arg in args] if self._dispatcher is None: diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_restore_args.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_restore_args.py index f4dbd9678..62d71c9bc 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_restore_args.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_restore_args.py @@ -92,9 +92,9 @@ class SingleReplicaArrayRestoreArgs(ArrayRestoreArgs): on one replica hosts and do broadcasting which should significantly improve the training start time at scale. - single_replica_sharding: - jax.sharding.NamedSharding object which describes the single replica - sharding to which current host belongs to. + single_replica_sharding: [Deprecated] This is provided for backward + compatibility only. It is not needed, as Orbax code will automatically + construct a single-replica sharding used for restoring before broadcasting. """ single_replica_sharding: jax.sharding.NamedSharding | None = None diff --git a/checkpoint/orbax/checkpoint/_src/serialization/pathways_handler_registry.py b/checkpoint/orbax/checkpoint/_src/serialization/pathways_handler_registry.py index 28c538f5e..5aee5ae2f 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/pathways_handler_registry.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/pathways_handler_registry.py @@ -90,6 +90,7 @@ def get_pathways_array_handler( **kwargs, ) -> type_handlers.ArrayHandler: """Returns the Pathways ArrayHandler with the given options.""" + # If not set, use whichever dispatcher implementation is available. checkpointing_impl = checkpointing_impl or CheckpointingImpl.from_options( use_colocated_python=True, diff --git a/checkpoint/orbax/checkpoint/_src/serialization/pathways_types.py b/checkpoint/orbax/checkpoint/_src/serialization/pathways_types.py index 30fd92670..b1c31001b 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/pathways_types.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/pathways_types.py @@ -24,7 +24,16 @@ class CheckpointingImpl(enum.Enum): - """The implementation to use for Pathways checkpointing.""" + """The implementation to use for Pathways checkpointing. + + These implementations include: + - Colocated Python + - Remote Python + - Persistence Array Handler + - No Dispatcher + + "No dispatcher" means that Pathways will not be used. + """ NO_DISPATCHER = enum.auto() COLOCATED_PYTHON = enum.auto() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py index 3a5c18785..87703ce19 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py @@ -282,11 +282,47 @@ class Loading: If True, restoration allows silent truncating/padding of arrays if the stored array shape does not match the target shape. Otherwise, raises an error. + raise_array_data_missing_error: + If True, raises an error if array data is missing. Otherwise allows + returning zeros from an array range that was not necessarily written to. + use_load_and_broadcast: Whether to use load-and-broadcast for multi-replica + loading. This is useful when the model has multiple replicas across + different sets of devices (commonly across multiple TPU slices, but also + applies to data-parallel model replicas within a single slice). Array + shardings must be structured so that the mesh has a dimension on which + all model weights are replicated. The checkpoint will then be loaded only + on the hosts and devices taken from replica `primary_replica_id` along the + `replica_axis_index` dimension. It will then be broadcast to all other + replicas. """ + @dataclasses.dataclass(frozen=True, kw_only=True) + class LoadAndBroadcastOptions: + """Used to configure load-and-broadcast behavior in multi-replica loading. + + replica_axis_index: Defines the axis of the global mesh along which + replicas are defined. E.g. all devices in + global_mesh.devices[replica_axis_index] are part of the same replica. + primary_replica_id: The id of the replica that is used to load and + broadcast the checkpoint. + broadcast_memory_limit_bytes: Specifies the memory size (in bytes) used + for broadcasting data. + broadcast_memory_scaling_factor: Specifies the fraction of available + memory to use for broadcasting data. + """ + + replica_axis_index: int | None = 0 + primary_replica_id: int | None = 0 + broadcast_memory_limit_bytes: int | None = None + broadcast_memory_scaling_factor: float | None = 0.75 + concurrent_bytes: int | None = None enable_padding_and_truncation: bool = False raise_array_data_missing_error: bool = True + use_load_and_broadcast: bool = False + load_and_broadcast_options: LoadAndBroadcastOptions = dataclasses.field( + default_factory=LoadAndBroadcastOptions + ) saving: Saving = dataclasses.field(default_factory=Saving) loading: Loading = dataclasses.field(default_factory=Loading) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py index c4c2992f7..d1557227f 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py @@ -33,6 +33,7 @@ from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0 from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils +from orbax.checkpoint.experimental.v1._src.serialization import registration from orbax.checkpoint.experimental.v1._src.serialization import types @@ -80,22 +81,7 @@ def _create_v0_array_handler( context: context_lib.Context, ) -> type_handlers_v0.ArrayHandler: """Creates a V0 array handler from a V1 context.""" - - saving_options = context.array_options.saving - primary_host = context.multiprocessing_options.primary_host - array_handler = type_handlers_v0.ArrayHandler( - primary_host=primary_host, - replica_id=None if primary_host is None else 0, - use_replica_parallel=saving_options.use_replica_parallel, - min_slice_bytes_for_replica_parallel=saving_options.min_slice_bytes_for_replica_parallel, - max_replicas_for_replica_parallel=saving_options.max_replicas_for_replica_parallel, - enable_replica_parallel_separate_folder=saving_options.enable_replica_parallel_separate_folder, - enable_write_sharding_file=saving_options.enable_write_sharding_file, - array_metadata_store=saving_options.array_metadata_store, - ) - - - return array_handler + return registration.get_array_handler(context) def _create_v0_saving_paraminfo( @@ -184,12 +170,17 @@ def _create_v0_restorearg( context: context_lib.Context, ) -> type_handlers_v0.ArrayRestoreArgs: """Creates a V0 `ArrayRestoreArgs` from V1 params.""" + restore_arg_cls = ( + type_handlers_v0.SingleReplicaArrayRestoreArgs + if context.array_options.loading.use_load_and_broadcast + else type_handlers_v0.ArrayRestoreArgs + ) value = param.value if value is None or isinstance(value, type): - return type_handlers_v0.ArrayRestoreArgs(restore_type=jax.Array) + return restore_arg_cls(restore_type=jax.Array) elif protocol_utils.is_subclass_protocol(value, AbstractShardedArray): value = typing.cast(AbstractShardedArray, value) - return type_handlers_v0.ArrayRestoreArgs( + return restore_arg_cls( restore_type=jax.Array, dtype=value.dtype, sharding=value.sharding, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py index 23e4ee7b5..e28ae7286 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py @@ -68,8 +68,7 @@ class NumpyMetadata(AbstractArray): def _create_v0_numpy_handler() -> type_handlers_v0.NumpyHandler: """Creates a V0 `NumpyHandler`.""" - numpy_handler = type_handlers_v0.NumpyHandler() - return numpy_handler + return registration.get_numpy_handler() def _create_v0_saving_paraminfo( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py new file mode 100644 index 000000000..4430b386c --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py @@ -0,0 +1,126 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines helpers for creating array `v0.TypeHandler`. + +Functions should return a `TypeHandler` that can be wrapped into a +:py:class:`LeafHandler`. It should return the appropriate handler based on +global settings and runtime (Pathways vs. mcJAX). + +This structure also helps prevent users from including Pathways dependencies in +their +binaries when they are not running on Pathways. The Pathways imports are +deferred until `is_pathways_backend()` returns True. + +Pathways dependencies should not be added to this file. +""" + +from orbax.checkpoint._src.serialization import jax_array_handlers +from orbax.checkpoint._src.serialization import pathways_handler_registry +from orbax.checkpoint._src.serialization import pathways_types +from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.synchronization import multihost + +_PATHWAYS_IMPORT_ERROR_MSG = """ +Failed to import Pathways dependencies. Please ensure you have +linked the Pathways dependencies required by Orbax to your binary. +These are found at +//orbax/checkpoint/experimental/v1:pathways_support. +Please note that such dependencies are not linked automatically +because Pathways has a lot of dependencies, which non-Pathways +users wish to avoid linking. +""" + + +def resolve_pathways_checkpointing_impl( + context: context_lib.Context, +) -> pathways_types.CheckpointingImpl: + """Returns the Pathways checkpointing implementation.""" + try: + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from .learning.deepmind.jax.ocean.remote_python import rp + # pytype: enable=import-error + # pylint: enable=g-import-not-at-top + except ImportError as e: + raise ImportError(_PATHWAYS_IMPORT_ERROR_MSG) from e + checkpointing_impl = context.pathways_options.checkpointing_impl + return checkpointing_impl or pathways_types.CheckpointingImpl.from_options( + use_colocated_python=False, # Not enabled unless explicitly requested. + use_remote_python=rp.available(), + use_persistence_array_handler=True, # Only used as a fallback. + ) + + +def get_array_handler( + context: context_lib.Context, +) -> type_handlers.ArrayHandler: + """Returns the TypeHandler for JAX arrays (pytree leaves).""" + saving_options = context.array_options.saving + loading_options = context.array_options.loading + primary_host = context.multiprocessing_options.primary_host + common_kwargs = dict( + primary_host=primary_host, + replica_id=None if primary_host is None else 0, + use_replica_parallel=saving_options.use_replica_parallel, + min_slice_bytes_for_replica_parallel=saving_options.min_slice_bytes_for_replica_parallel, + max_replicas_for_replica_parallel=saving_options.max_replicas_for_replica_parallel, + enable_replica_parallel_separate_folder=saving_options.enable_replica_parallel_separate_folder, + enable_write_sharding_file=saving_options.enable_write_sharding_file, + array_metadata_store=saving_options.array_metadata_store, + ) + if loading_options.use_load_and_broadcast: + load_and_broadcast_kwargs = dict( + replica_axis_index=loading_options.load_and_broadcast_options.replica_axis_index, + primary_replica_id=loading_options.load_and_broadcast_options.primary_replica_id, + broadcast_memory_limit_bytes=loading_options.load_and_broadcast_options.broadcast_memory_limit_bytes, + broadcast_memory_scaling_factor=loading_options.load_and_broadcast_options.broadcast_memory_scaling_factor, + ) + else: + load_and_broadcast_kwargs = dict() + + if multihost.is_pathways_backend(): + checkpointing_impl = resolve_pathways_checkpointing_impl(context) + return pathways_handler_registry.get_pathways_array_handler( + use_single_replica_array_handler=loading_options.use_load_and_broadcast, + checkpointing_impl=checkpointing_impl, + **common_kwargs, + **load_and_broadcast_kwargs, + ) + else: + if loading_options.use_load_and_broadcast: + return jax_array_handlers.SingleReplicaArrayHandler( + dispatcher=None, + **common_kwargs, + **load_and_broadcast_kwargs, + ) + else: + return jax_array_handlers.ArrayHandler(dispatcher=None, **common_kwargs) + + +def get_numpy_handler() -> type_handlers.NumpyHandler: + """Returns the TypeHandler for Numpy arrays.""" + if multihost.is_pathways_backend(): + return pathways_handler_registry.get_pathways_numpy_handler() + else: + return type_handlers.NumpyHandler() + + +def get_scalar_handler() -> type_handlers.ScalarHandler: + """Returns the TypeHandler for scalars.""" + if multihost.is_pathways_backend(): + return pathways_handler_registry.get_pathways_scalar_handler() + else: + return type_handlers.ScalarHandler() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py index 96ed0163b..91a6d7fc0 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py @@ -39,8 +39,7 @@ def _create_v0_scalar_handler() -> type_handlers_v0.ScalarHandler: """Creates a V0 ScalarHandler.""" - scalar_handler = type_handlers_v0.ScalarHandler() - return scalar_handler + return registration.get_scalar_handler() def _create_v0_saving_paraminfo( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index b76270ab1..5e01e4968 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -35,7 +35,7 @@ from orbax.checkpoint import test_utils from orbax.checkpoint._src.multihost import multihost as multihost_v0 from orbax.checkpoint._src.path import atomicity -from orbax.checkpoint._src.serialization import serialization +from orbax.checkpoint._src.serialization import serialization as serialization_v0 from orbax.checkpoint._src.tree import utils as tree_utils import orbax.checkpoint.experimental.v1 as ocp from orbax.checkpoint.experimental.v1._src.handlers import registration @@ -126,16 +126,16 @@ def test_load_default(self, use_async): def test_save_pytree_async(self): start_serialize = threading.Event() - original_serialize = serialization.async_serialize_from_host + original_serialize = serialization_v0.async_serialize_from_host def mock_serialize(*args, **kwargs): start_serialize.wait() # Wait for explicit signal before proceeding. return original_serialize(*args, **kwargs) - # Serialization to disk does not start until receiving an explicit signal. + # serialization to disk does not start until receiving an explicit signal. self.enter_context( mock.patch.object( - serialization, 'async_serialize_from_host', new=mock_serialize + serialization_v0, 'async_serialize_from_host', new=mock_serialize ) ) @@ -971,3 +971,39 @@ def test_save_checkpointables_directory_consistency_failure(self): ValueError, 'Directory path mismatch in multi-process save' ): ocp.save_pytree(directory, self.pytree) + + def test_load_and_broadcast(self): + replica_count = 2 + partition_count = jax.device_count() // replica_count + mesh = jax.sharding.Mesh( + np.asarray(jax.devices()).reshape(replica_count, partition_count), + ('replica', 'model'), + ) + spec = jax.sharding.PartitionSpec(None, 'model') + sharding = jax.sharding.NamedSharding(mesh, spec) + arr = array_test_utils.create_sharded_array( + np.arange(4 * 32).reshape(4, 32), sharding + ) + self.assertEqual( + sharding.shard_shape((4, 32)), (4, 32 // partition_count) + ) + with ocp.Context( + array_options=ocp.options.ArrayOptions( + loading=ocp.options.ArrayOptions.Loading( + use_load_and_broadcast=True, + ) + ) + ): + ocp.save_pytree(self.directory, [arr]) + with self.subTest('with_abstract_pytree'): + loaded = ocp.load_pytree( + self.directory, [array_test_utils.as_abstract_type(arr)] + ) + test_utils.assert_tree_equal(self, [arr], loaded) + with self.subTest('without_abstract_pytree'): + with self.assertRaisesRegex( + ValueError, + 'Must provide `sharding` to restore with' + ' `SingleReplicaArrayHandler`', + ): + ocp.load_pytree(self.directory) diff --git a/checkpoint/orbax/checkpoint/test_utils.py b/checkpoint/orbax/checkpoint/test_utils.py index fd3b330a0..db0c9d873 100644 --- a/checkpoint/orbax/checkpoint/test_utils.py +++ b/checkpoint/orbax/checkpoint/test_utils.py @@ -668,46 +668,14 @@ def create_single_replica_restore_args( arr: jax.Array, mesh: jax.sharding.Mesh, pspec: jax.sharding.PartitionSpec, - replica_axis_index: int, ): - replica_devices = _replica_devices(mesh.devices, replica_axis_index) - replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) - ss_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) - return type_handlers.SingleReplicaArrayRestoreArgs( sharding=jax.sharding.NamedSharding(mesh, pspec), - single_replica_sharding=ss_sharding, global_shape=arr.shape, dtype=arr.dtype, ) -def _find_idx(array: np.ndarray, replica_axis_idx: int): - """Returns the index along given dimension that the current host belongs to.""" - idx = None - for idx, val in np.ndenumerate(array): - if val.process_index == multihost.process_index(): - break - return idx[replica_axis_idx] - - -def _replica_devices(device_array: np.ndarray, replica_axis_idx: int): - """Returns the devices from the replica that current host belongs to. - - Replicas are assumed to be restricted to the first axis. - - Args: - device_array: devices of the mesh that can be obtained by mesh.devices() - replica_axis_idx: axis dimension along which replica is taken - - Returns: - devices inside the replica that current host is in - """ - idx = _find_idx(device_array, replica_axis_idx) - replica_result = np.take(device_array, idx, axis=replica_axis_idx) - return np.expand_dims(replica_result, axis=replica_axis_idx) - - class TestLimitInFlightBytes(limits.LimitInFlightBytes): """Limits in-flight bytes when reading/writing checkpoints per process."""