From 0c8d5ab7ad3f0d1f5f8833a481c34f90721246e0 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Thu, 19 Feb 2026 03:40:24 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 872315040 --- .../orbax/checkpoint/_src/arrays/fragments.py | 41 +++- .../checkpoint/_src/arrays/fragments_test.py | 181 ++++++++++++++++++ 2 files changed, 221 insertions(+), 1 deletion(-) diff --git a/checkpoint/orbax/checkpoint/_src/arrays/fragments.py b/checkpoint/orbax/checkpoint/_src/arrays/fragments.py index ef369ebbb..cefbd6d74 100644 --- a/checkpoint/orbax/checkpoint/_src/arrays/fragments.py +++ b/checkpoint/orbax/checkpoint/_src/arrays/fragments.py @@ -24,12 +24,19 @@ class methods: FS is `AbstractFragments` then x may additionally be a `jax.ShapeDtypeStruct`. - `none_of(x)`: gives an FS shaped like x, with no fragments. + - `addressable_shards_of(x)`: gives an FS shaped like x, with one fragment for + each distinct addressable shard of x. If FS is `AbstractFragments` then x + may additionally be a `jax.ShapeDtypeStruct`. + - `of(x, indices=...)`: gives an FS shaped like x, with one fragment for each + index in `indices`. If FS is concrete then the fragment values will be + slices of x. If FS is `AbstractFragments` then x may additionally be + a `jax.ShapeDtypeStruct`. """ # TODO(b/465196209): Remove when support for Python 3.10 is dropped. 
from __future__ import annotations import dataclasses -from typing import Any, ClassVar, Generic, Literal, Sequence, TypeAlias, TypeVar +from typing import Any, ClassVar, Generator, Generic, Literal, Sequence, TypeAlias, TypeVar import jax import numpy as np @@ -369,6 +376,16 @@ def none_of(cls: type[FS], x: Any) -> FS: """Returns a Fragments with no fragments.""" return cls._of(x, indices=[]) + @classmethod + def addressable_shards_of(cls: type[FS], x: Any) -> FS: + """Returns a Fragments exactly spanning the distinct addressable shards of `x`.""" + return cls._of(x, indices=_gen_distinct_addressable_indices(x)) + + @classmethod + def of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS: + """Returns a Fragments exactly spanning the given indices.""" + return cls._of(x, indices=indices) + def is_degenerate(self) -> bool: """Whether this contains only degenerate fragments.""" return all(f.is_degenerate() for f in self.fragments) @@ -534,6 +551,28 @@ def addressable_shards(x: jax.Array | jax.ShapeDtypeStruct) -> list[Index]: ] +def _gen_distinct_addressable_indices( + x: np.ndarray | jax.Array | jax.ShapeDtypeStruct, +) -> Generator[Index, None, None]: + """Yields fragment indices for distinct addressable shards of x.""" + match x: + case jax.Array() | jax.ShapeDtypeStruct(): + if not x.sharding: + raise ValueError( + 'Cannot determine addressable shards of jax.ShapeDtypeStruct with' + ' no sharding.' 
+ ) + indices = addressable_shards(x) + case np.ndarray(): + indices = (tuple(slice(0, dim, 1) for dim in x.shape),) + case _: + raise TypeError(f'Unsupported type: {type(x)}') + distinct_indices = sorted({ + *(np_utils.to_hashable_index(i, shape=x.shape) for i in indices) + }) + yield from (np_utils.from_hashable_index(i) for i in distinct_indices) + + def abstract_fragments( x: jax.Array | jax.ShapeDtypeStruct | FS, ) -> AbstractFragments: diff --git a/checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py b/checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py index f2dfe2e8c..233ef5abf 100644 --- a/checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py +++ b/checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses +from unittest import mock from absl.testing import absltest from absl.testing import parameterized @@ -20,6 +21,7 @@ import jax.numpy as jnp import numpy as np from orbax.checkpoint._src.arrays import fragments as array_fragments +from orbax.checkpoint._src.arrays import numpy_utils as np_utils AbstractFragment = array_fragments.AbstractFragment AbstractFragments = array_fragments.AbstractFragments @@ -647,6 +649,27 @@ def test_non_full_fragments_raises_exception( np.asarray(fragments) +def _fake_sharding(shape) -> jax.sharding.NamedSharding: + sharding = mock.Mock(spec=jax.sharding.NamedSharding) + # List each slice twice to allow tests to show that they're deduplicated. 
+ sharding.addressable_devices_indices_map.return_value = dict( + (mock.Mock(), np.s_[0:2:1, y:y+1:1]) for y in [*range(shape[1])[::2]] * 2 + ) + return sharding + + +def _fake_jnp_ones(shape, dtype, sharding) -> jax.Array: + return mock.Mock( + spec=jax.Array, + shape=shape, + dtype=dtype, + sharding=sharding, + __getitem__=lambda self, index: jnp.ones( + np_utils.slice_shape(index), dtype=dtype + ), + ) + + class FragmentsClassMethodsTest(parameterized.TestCase): @parameterized.named_parameters( @@ -746,6 +769,164 @@ def test_abstract_fragments_from_none_of(self, value_fn): with self.assertRaisesRegex(TypeError, 'Fragment value must be'): fragments_t.none_of(None) # pytype: disable=wrong-arg-types + def test_np_fragments_from_addressable_shards_of_np_array(self): + # NumPy arrays aren't sharded so we expect a single fragment spanning the + # entire array, the same as if we had used `all_of()`. + fragments_t = NpFragments + shape = (2, 3) + dtype = np.dtype(np.float32) + fragment_t = fragments_t.FRAGMENT_T + np_api = fragment_t.NP_API + other_np_api = np if np_api is jnp else jnp + + with self.subTest('with_array'): + x = np_api.ones(shape, dtype=dtype) + fs = fragments_t.addressable_shards_of(x) + self.assertEqual(fs.shape, shape) + self.assertEqual(fs.dtype, dtype) + self.assertEqual( + fs.fragments, [fragment_t(index=np.s_[0:2:1, 0:3:1], value=x)] + ) + + with self.subTest('with_wrong_array_type_raises'): + with self.assertRaisesRegex(TypeError, 'Fragment value must be'): + fragments_t.addressable_shards_of(other_np_api.ones(shape, dtype=dtype)) + + with self.subTest('with_shape_dtype_struct_raises'): + x = jax.ShapeDtypeStruct(shape, dtype) + with self.assertRaisesRegex(TypeError, 'Fragment value must be'): + fragments_t.addressable_shards_of(x) + + def test_jax_fragments_from_addressable_shards_of_jnp_array(self): + fragments_t = JaxFragments + shape = (2, 3) + dtype = np.dtype(np.float32) + fragment_t = fragments_t.FRAGMENT_T + np_api = fragment_t.NP_API + 
other_np_api = np if np_api is jnp else jnp + + with self.subTest('with_array'): + x = _fake_jnp_ones( + shape=shape, dtype=dtype, sharding=_fake_sharding(shape) + ) + fs = fragments_t.addressable_shards_of(x) + self.assertEqual(fs.shape, shape) + self.assertEqual(fs.dtype, dtype) + self.assertEqual(fs.fragments, [ + fragment_t(index=np.s_[0:2:1, 0:1:1], value=x[0:2:1, 0:1:1]), + fragment_t(index=np.s_[0:2:1, 2:3:1], value=x[0:2:1, 2:3:1]), + ]) + + with self.subTest('with_wrong_array_type_raises'): + with self.assertRaisesRegex(TypeError, 'Fragment value must be'): + fragments_t.addressable_shards_of(other_np_api.ones(shape, dtype=dtype)) + + with self.subTest('with_shape_dtype_struct_raises'): + x = jax.ShapeDtypeStruct(shape, dtype) + with self.assertRaisesRegex(TypeError, 'Fragment value must be'): + fragments_t.addressable_shards_of(x) + + def test_abstract_fragments_from_addressable_shards_of_np_array(self): + # NumPy arrays aren't sharded so we expect a single fragment spanning the + # entire array, the same as if we had used `all_of()`. 
+ fragments_t = AbstractFragments + shape = (2, 3) + dtype = np.dtype(np.float32) + fragment_t = fragments_t.FRAGMENT_T + + x = np.ones(shape, dtype=dtype) + fs = fragments_t.addressable_shards_of(x) + self.assertEqual(fs.shape, shape) + self.assertEqual(fs.dtype, dtype) + self.assertEqual(fs.fragments, [fragment_t(index=np.s_[0:2:1, 0:3:1])]) + + @parameterized.named_parameters( + ('shape_dtype_struct', jax.ShapeDtypeStruct), + ('jnp_array', _fake_jnp_ones), + ) + def test_abstract_fragments_from_addressable_shards_of(self, value_fn): + fragments_t = AbstractFragments + shape = (2, 3) + dtype = np.dtype(np.float32) + fragment_t = fragments_t.FRAGMENT_T + + x = value_fn(shape, dtype=dtype, sharding=_fake_sharding(shape)) + fs = fragments_t.addressable_shards_of(x) + self.assertEqual(fs.shape, shape) + self.assertEqual(fs.dtype, dtype) + self.assertEqual(fs.fragments, [ + fragment_t(index=np.s_[0:2:1, 0:1:1]), + fragment_t(index=np.s_[0:2:1, 2:3:1]), + ]) + + def test_abstract_fragments_from_addressable_shards_of_shape_dtype_struct_with_no_sharding_raises( + self, + ): + fragments_t = AbstractFragments + shape = (2, 3) + dtype = np.dtype(np.float32) + + x = jax.ShapeDtypeStruct(shape, dtype=dtype) + with self.assertRaisesRegex( + ValueError, 'Cannot determine addressable shards' + ): + fragments_t.addressable_shards_of(x) + + @parameterized.named_parameters( + ('np_array', NpFragments), + ('jnp_array', JaxFragments), + ) + def test_concrete_fragments_of(self, fragments_t: FragmentsT): + shape = (2, 3) + dtype = np.dtype(np.float32) + fragment_t = fragments_t.FRAGMENT_T + np_api = fragment_t.NP_API + other_np_api = np if np_api is jnp else jnp + + with self.subTest('with_array'): + x = np_api.ones(shape, dtype=dtype) + fs = fragments_t.of(x, indices=[np.s_[0:2:1, 0:1:1], np.s_[0:2:1, 2:3:1]]) + self.assertEqual(fs.shape, shape) + self.assertEqual(fs.dtype, dtype) + self.assertEqual( + fs.fragments, [ + fragment_t(index=np.s_[0:2:1, 0:1:1], value=x[0:2:1, 0:1:1]), + 
fragment_t(index=np.s_[0:2:1, 2:3:1], value=x[0:2:1, 2:3:1]), + ] + ) + + with self.subTest('with_wrong_array_type_raises'): + with self.assertRaisesRegex(TypeError, 'Fragment value must be'): + fragments_t.of( + other_np_api.ones(shape, dtype=dtype), + indices=[np.s_[0:2:1, 0:1:1], np.s_[0:2:1, 2:3:1]], + ) + + with self.subTest('with_shape_dtype_struct_raises'): + x = jax.ShapeDtypeStruct(shape, dtype) + with self.assertRaisesRegex(TypeError, 'Fragment value must be'): + fragments_t.of(x, indices=[np.s_[0:2:1, 0:1:1], np.s_[0:2:1, 2:3:1]]) + + @parameterized.named_parameters( + ('shape_dtype_struct', jax.ShapeDtypeStruct), + ('np_array', np.ones), + ('jnp_array', jnp.ones), + ) + def test_abstract_fragments_of(self, value_fn): + fragments_t = AbstractFragments + shape = (2, 3) + dtype = np.dtype(np.float32) + fragment_t = fragments_t.FRAGMENT_T + + x = value_fn(shape, dtype=dtype) + fs = fragments_t.of(x, indices=[np.s_[0:2:1, 0:1:1], np.s_[0:2:1, 2:3:1]]) + self.assertEqual(fs.shape, shape) + self.assertEqual(fs.dtype, dtype) + self.assertEqual(fs.fragments, [ + fragment_t(index=np.s_[0:2:1, 0:1:1]), + fragment_t(index=np.s_[0:2:1, 2:3:1]), + ]) + @parameterized.named_parameters( ('abstract_fragments', AbstractFragments),