Skip to content

Commit 40b467f

Browse files
committed
Added persistence API read calls to deserialize
1 parent c4e6767 commit 40b467f

1 file changed

Lines changed: 61 additions & 1 deletion

File tree

axlearn/common/array_serialization.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""
1919
import asyncio
2020
import functools
21+
import itertools
2122
import math
2223
import os
2324
import threading
@@ -26,6 +27,7 @@
2627
from concurrent import futures
2728
from concurrent.futures import ThreadPoolExecutor
2829
from dataclasses import dataclass
30+
import datetime
2931
from typing import Any, Callable, Optional, Sequence, Union
3032

3133
import jax
@@ -36,6 +38,7 @@
3638
from jax._src import array, typing
3739
from jax._src.layout import Layout
3840
from jax.experimental.array_serialization import serialization
41+
from pathwaysutils.persistence import helper
3942

4043
from axlearn.common.utils import Tensor
4144

@@ -578,9 +581,66 @@ async def _run_serializer():
578581
# has finished writing.
579582
self._start_async_commit(on_commit_callback)
580583

584+
def deserialize(
    self,
    shardings: Sequence[Union[jax.sharding.Sharding, Layout]],
    tensorstore_specs: Sequence[dict[str, Any]],
    global_shapes: Optional[Sequence[array.Shape]] = None,
    dtypes: Optional[Sequence[typing.DTypeLike]] = None,
    concurrent_gb: int = 32,
):
    """Restores arrays through the persistence helper's bulk read call.

    Arrays are grouped by `(gs:// directory, mesh)` and each group is read with a
    single `helper.read_arrays` call. This method blocks until every group's read
    future has resolved.

    Args:
        shardings: One sharding (or `Layout` wrapping a sharding) per array.
        tensorstore_specs: One tensorstore spec per array. Each spec must contain a
            "kvstore" entry with driver "gcs", a "bucket" and a "path"; "metadata"
            is consulted for shape/dtype when those are not supplied explicitly.
        global_shapes: Optional per-array shapes. A `None` sequence or `None` entry
            falls back to `spec["metadata"]["shape"]`.
        dtypes: Optional per-array dtypes. A `None` sequence or `None` entry falls
            back to `spec["metadata"]["dtype"]`.
        concurrent_gb: Unused by this implementation; kept so that the signature
            stays compatible with existing callers.

    Returns:
        A list of the restored arrays, in the same order as `tensorstore_specs`.

    Raises:
        ValueError: If the argument lengths disagree, a sharding is not a
            `jax.sharding.Sharding`, or a spec is not backed by the "gcs" driver.
    """
    # The bulk read path exposes no concurrency limit to plumb this into.
    del concurrent_gb
    self.wait_until_finished()

    num_arrays = len(tensorstore_specs)
    # Normalise the optional arguments so that all four sequences can be zipped.
    if global_shapes is None:
        global_shapes = [None] * num_arrays
    if dtypes is None:
        dtypes = [None] * num_arrays
    for arg_name, values in (
        ("shardings", shardings),
        ("global_shapes", global_shapes),
        ("dtypes", dtypes),
    ):
        if len(values) != num_arrays:
            raise ValueError(
                f"{arg_name} has {len(values)} entries but tensorstore_specs has "
                f"{num_arrays}."
            )

    # Maps (location, mesh) -> the parallel argument lists for one bulk read.
    bulk_read_inputs = {}
    for index, (tensorstore_spec, sharding, global_shape, dtype) in enumerate(
        zip(tensorstore_specs, shardings, global_shapes, dtypes)
    ):
        sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
        if not isinstance(sharding, jax.sharding.Sharding):
            raise ValueError(
                "sharding passed to deserialization should be specified, concrete and"
                f" an instance of `jax.sharding.Sharding`. Got {sharding}"
            )

        kvstore = tensorstore_spec["kvstore"]
        if kvstore.get("driver", "") != "gcs":
            raise ValueError("Only GCS backed checkpoints have been tested")

        bucket = kvstore["bucket"]
        # The final path component is the array name; the rest is the directory.
        path, name = os.path.split(kvstore["path"])
        location = f"gs://{bucket}/{path}"
        shape = tensorstore_spec["metadata"]["shape"] if global_shape is None else global_shape
        dtype = tensorstore_spec["metadata"]["dtype"] if dtype is None else dtype

        # NOTE(review): `.mesh` is not part of the base `Sharding` interface; this
        # presumably expects mesh-backed shardings (e.g. NamedSharding) -- confirm.
        group = bulk_read_inputs.setdefault(
            (location, sharding.mesh),
            dict(indices=[], names=[], dtypes=[], shapes=[], shardings=[]),
        )
        group["indices"].append(index)
        group["names"].append(name)
        group["dtypes"].append(dtype)
        group["shapes"].append(shape)
        group["shardings"].append(sharding)

    # Filled by original position so the output order matches the input order.
    results = [None] * num_arrays
    for (location, mesh), group in bulk_read_inputs.items():
        arrays, read_future = helper.read_arrays(
            location=location,
            names=group["names"],
            dtypes=group["dtypes"],
            shapes=group["shapes"],
            shardings=group["shardings"],
            devices=mesh.devices,
            timeout=datetime.timedelta(minutes=5),
        )
        # Block until this group's read has completed before handing arrays back.
        read_future.result()

        for index, restored in zip(group["indices"], arrays):
            results[index] = restored

    return results
640+
581641
# Copied from (with modifications)
582642
# https://github.com/jax-ml/jax/blob/66037d10e7742c4fcadd07f0459a00813ec7ed5f/jax/experimental/array_serialization/serialization.py#L413-L429
583-
def deserialize(
643+
def deserialize_original(
584644
self,
585645
shardings: Sequence[Union[jax.sharding.Sharding, Layout]],
586646
tensorstore_specs: Sequence[dict[str, Any]],

0 commit comments

Comments
 (0)