Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion checkpoint/orbax/checkpoint/_src/arrays/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,19 @@ class methods:
FS is `AbstractFragments` then x may additionally be a
`jax.ShapeDtypeStruct`.
- `none_of(x)`: gives an FS shaped like x, with no fragments.
- `addressable_shards_of(x)`: gives an FS shaped like x, with one fragment for
each distinct addressable shard of x. If FS is `AbstractFragments` then x
may additionally be a `jax.ShapeDtypeStruct`.
- `of(x, *indices)`: gives an FS shaped like x, with one fragment for each
index in `indices`. If FS is concrete then the fragment values will be
slices of x. If FS is `AbstractFragments` then x may additionally be
a `jax.ShapeDtypeStruct`.
"""
# TODO(b/465196209): Remove when support for Python 3.10 is dropped.
from __future__ import annotations

import dataclasses
from typing import Any, ClassVar, Generic, Literal, Sequence, TypeAlias, TypeVar
from typing import Any, ClassVar, Generator, Generic, Literal, Sequence, TypeAlias, TypeVar

import jax
import numpy as np
Expand Down Expand Up @@ -369,6 +376,16 @@ def none_of(cls: type[FS], x: Any) -> FS:
"""Returns a Fragments with no fragments."""
return cls._of(x, indices=[])

@classmethod
def addressable_shards_of(cls: type[FS], x: Any) -> FS:
"""Returns a Fragments exactly spanning the distinct addressable shards of `x`."""
return cls._of(x, indices=_gen_distinct_addressable_indices(x))

@classmethod
def of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
"""Returns a Fragments exactly spanning the given indices."""
return cls._of(x, indices=indices)

def is_degenerate(self) -> bool:
"""Whether this contains only degenerate fragments."""
return all(f.is_degenerate() for f in self.fragments)
Expand Down Expand Up @@ -534,6 +551,28 @@ def addressable_shards(x: jax.Array | jax.ShapeDtypeStruct) -> list[Index]:
]


def _gen_distinct_addressable_indices(
x: np.ndarray | jax.Array | jax.ShapeDtypeStruct,
) -> Generator[Index, None, None]:
"""Yields fragment indices for distinct addressable shards of x."""
match x:
case jax.Array() | jax.ShapeDtypeStruct():
if not x.sharding:
raise ValueError(
'Cannot determine addressable shards of jax.ShapeDtypeStruct with'
' no sharding.'
)
indices = addressable_shards(x)
case np.ndarray():
indices = (tuple(slice(0, dim, 1) for dim in x.shape),)
case _:
raise TypeError(f'Unsupported type: {type(x)}')
distinct_indices = sorted({
*(np_utils.to_hashable_index(i, shape=x.shape) for i in indices)
})
yield from (np_utils.from_hashable_index(i) for i in distinct_indices)


def abstract_fragments(
x: jax.Array | jax.ShapeDtypeStruct | FS,
) -> AbstractFragments:
Expand Down
181 changes: 181 additions & 0 deletions checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.

import dataclasses
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint._src.arrays import fragments as array_fragments
from orbax.checkpoint._src.arrays import numpy_utils as np_utils

AbstractFragment = array_fragments.AbstractFragment
AbstractFragments = array_fragments.AbstractFragments
Expand Down Expand Up @@ -647,6 +649,27 @@ def test_non_full_fragments_raises_exception(
np.asarray(fragments)


def _fake_sharding(shape) -> jax.sharding.NamedSharding:
sharding = mock.Mock(spec=jax.sharding.NamedSharding)
# List each slice twice to all tests to show that they're deduplicated.
sharding.addressable_devices_indices_map.return_value = dict(
(mock.Mock(), np.s_[0:2:1, y:y+1:1]) for y in [*range(shape[1])[::2]] * 2
)
return sharding


def _fake_jnp_ones(shape, dtype, sharding) -> jax.Array:
return mock.Mock(
spec=jax.Array,
shape=shape,
dtype=dtype,
sharding=sharding,
__getitem__=lambda self, index: jnp.ones(
np_utils.slice_shape(index), dtype=dtype
),
)


class FragmentsClassMethodsTest(parameterized.TestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -746,6 +769,164 @@ def test_abstract_fragments_from_none_of(self, value_fn):
with self.assertRaisesRegex(TypeError, 'Fragment value must be'):
fragments_t.none_of(None) # pytype: disable=wrong-arg-types

def test_np_fragments_from_addressable_shards_of_np_array(self):
# NumPy arrays aren't sharded so we expect a single fragment spanning the
# entire array, the same as if we had used `all_of()`.
fragments_t = NpFragments
shape = (2, 3)
dtype = np.dtype(np.float32)
fragment_t = fragments_t.FRAGMENT_T
np_api = fragment_t.NP_API
other_np_api = np if np_api is jnp else jnp

with self.subTest('with_array'):
x = np_api.ones(shape, dtype=dtype)
fs = fragments_t.addressable_shards_of(x)
self.assertEqual(fs.shape, shape)
self.assertEqual(fs.dtype, dtype)
self.assertEqual(
fs.fragments, [fragment_t(index=np.s_[0:2:1, 0:3:1], value=x)]
)

with self.subTest('with_wrong_array_type_raises'):
with self.assertRaisesRegex(TypeError, 'Fragment value must be'):
fragments_t.addressable_shards_of(other_np_api.ones(shape, dtype=dtype))

with self.subTest('with_shape_dtype_struct_raises'):
x = jax.ShapeDtypeStruct(shape, dtype)
with self.assertRaisesRegex(TypeError, 'Fragment value must be'):
fragments_t.addressable_shards_of(x)

def test_jax_fragments_from_addressable_shards_of_jnp_array(self):
fragments_t = JaxFragments
shape = (2, 3)
dtype = np.dtype(np.float32)
fragment_t = fragments_t.FRAGMENT_T
np_api = fragment_t.NP_API
other_np_api = np if np_api is jnp else jnp

with self.subTest('with_array'):
x = _fake_jnp_ones(
shape=shape, dtype=dtype, sharding=_fake_sharding(shape)
)
fs = fragments_t.addressable_shards_of(x)
self.assertEqual(fs.shape, shape)
self.assertEqual(fs.dtype, dtype)
self.assertEqual(fs.fragments, [
fragment_t(index=np.s_[0:2:1, 0:1:1], value=x[0:2:1, 0:1:1]),
fragment_t(index=np.s_[0:2:1, 2:3:1], value=x[0:2:1, 2:3:1]),
])

with self.subTest('with_wrong_array_type_raises'):
with self.assertRaisesRegex(TypeError, 'Fragment value must be'):
fragments_t.addressable_shards_of(other_np_api.ones(shape, dtype=dtype))

with self.subTest('with_shape_dtype_struct_raises'):
x = jax.ShapeDtypeStruct(shape, dtype)
with self.assertRaisesRegex(TypeError, 'Fragment value must be'):
fragments_t.addressable_shards_of(x)

def test_abstract_fragments_from_addressable_shards_of_np_array(self):
# NumPy arrays aren't sharded so we expect a single fragment spanning the
# entire array, the same as if we had used `all_of()`.
fragments_t = AbstractFragments
shape = (2, 3)
dtype = np.dtype(np.float32)
fragment_t = fragments_t.FRAGMENT_T

x = np.ones(shape, dtype=dtype)
fs = fragments_t.addressable_shards_of(x)
self.assertEqual(fs.shape, shape)
self.assertEqual(fs.dtype, dtype)
self.assertEqual(fs.fragments, [fragment_t(index=np.s_[0:2:1, 0:3:1])])

@parameterized.named_parameters(
('shape_dtype_struct', jax.ShapeDtypeStruct),
('jnp_array', _fake_jnp_ones),
)
def test_abstract_fragments_from_addressable_shards_of(self, value_fn):
fragments_t = AbstractFragments
shape = (2, 3)
dtype = np.dtype(np.float32)
fragment_t = fragments_t.FRAGMENT_T

x = value_fn(shape, dtype=dtype, sharding=_fake_sharding(shape))
fs = fragments_t.addressable_shards_of(x)
self.assertEqual(fs.shape, shape)
self.assertEqual(fs.dtype, dtype)
self.assertEqual(fs.fragments, [
fragment_t(index=np.s_[0:2:1, 0:1:1]),
fragment_t(index=np.s_[0:2:1, 2:3:1]),
])

def test_abstract_fragments_from_addressable_shards_of_shape_dtype_struct_with_no_sharding_raises(
self,
):
fragments_t = AbstractFragments
shape = (2, 3)
dtype = np.dtype(np.float32)

x = jax.ShapeDtypeStruct(shape, dtype=dtype)
with self.assertRaisesRegex(
ValueError, 'Cannot determine addressable shards'
):
fragments_t.addressable_shards_of(x)

@parameterized.named_parameters(
('np_array', NpFragments),
('jnp_array', JaxFragments),
)
def test_concrete_fragments_of(self, fragments_t: FragmentsT):
shape = (2, 3)
dtype = np.dtype(np.float32)
fragment_t = fragments_t.FRAGMENT_T
np_api = fragment_t.NP_API
other_np_api = np if np_api is jnp else jnp

with self.subTest('with_array'):
x = np_api.ones(shape, dtype=dtype)
fs = fragments_t.of(x, indices=[np.s_[0:2:1, 0:1:1], np.s_[0:2:1, 2:3:1]])
self.assertEqual(fs.shape, shape)
self.assertEqual(fs.dtype, dtype)
self.assertEqual(
fs.fragments, [
fragment_t(index=np.s_[0:2:1, 0:1:1], value=x[0:2:1, 0:1:1]),
fragment_t(index=np.s_[0:2:1, 2:3:1], value=x[0:2:1, 2:3:1]),
]
)

with self.subTest('with_wrong_array_type_raises'):
with self.assertRaisesRegex(TypeError, 'Fragment value must be'):
fragments_t.of(
other_np_api.ones(shape, dtype=dtype),
indices=[np.s_[0:2:1, 0:1:1], np.s_[0:2:1, 2:3:1]],
)

with self.subTest('with_shape_dtype_struct_raises'):
x = jax.ShapeDtypeStruct(shape, dtype)
with self.assertRaisesRegex(TypeError, 'Fragment value must be'):
fragments_t.of(x, indices=[np.s_[0:2:1, 0:1:1], np.s_[0:2:1, 2:3:1]])

@parameterized.named_parameters(
('shape_dtype_struct', jax.ShapeDtypeStruct),
('np_array', np.ones),
('jnp_array', jnp.ones),
)
def test_abstract_fragments_of(self, value_fn):
fragments_t = AbstractFragments
shape = (2, 3)
dtype = np.dtype(np.float32)
fragment_t = fragments_t.FRAGMENT_T

x = value_fn(shape, dtype=dtype)
fs = fragments_t.of(x, indices=[np.s_[0:2:1, 0:1:1], np.s_[0:2:1, 2:3:1]])
self.assertEqual(fs.shape, shape)
self.assertEqual(fs.dtype, dtype)
self.assertEqual(fs.fragments, [
fragment_t(index=np.s_[0:2:1, 0:1:1]),
fragment_t(index=np.s_[0:2:1, 2:3:1]),
])


@parameterized.named_parameters(
('abstract_fragments', AbstractFragments),
Expand Down
Loading