|
18 | 18 | """ |
19 | 19 | import asyncio |
20 | 20 | import functools |
| 21 | +import itertools |
21 | 22 | import math |
22 | 23 | import os |
23 | 24 | import threading |
|
26 | 27 | from concurrent import futures |
27 | 28 | from concurrent.futures import ThreadPoolExecutor |
28 | 29 | from dataclasses import dataclass |
| 30 | +import datetime |
29 | 31 | from typing import Any, Callable, Optional, Sequence, Union |
30 | 32 |
|
31 | 33 | import jax |
|
36 | 38 | from jax._src import array, typing |
37 | 39 | from jax._src.layout import Layout |
38 | 40 | from jax.experimental.array_serialization import serialization |
| 41 | +from pathwaysutils.persistence import helper |
39 | 42 |
|
40 | 43 | from axlearn.common.utils import Tensor |
41 | 44 |
|
@@ -578,9 +581,66 @@ async def _run_serializer(): |
578 | 581 | # has finished writing. |
579 | 582 | self._start_async_commit(on_commit_callback) |
580 | 583 |
|
def deserialize(
    self,
    shardings: Sequence[Union[jax.sharding.Sharding, Layout]],
    tensorstore_specs: Sequence[dict[str, Any]],
    global_shapes: Optional[Sequence[array.Shape]] = None,
    dtypes: Optional[Sequence[typing.DTypeLike]] = None,
    concurrent_gb: int = 32,
):
    """Restores arrays by issuing one bulk read per (location, mesh) group.

    Entries are grouped by the GCS directory that holds them and by the mesh of their
    sharding; each group is restored with a single `helper.read_arrays` call.

    Args:
        shardings: One sharding (or `Layout` wrapping a sharding) per array.
        tensorstore_specs: One tensorstore spec per array. Each spec must have a "kvstore"
            entry with driver "gcs" plus "bucket" and "path" keys.
        global_shapes: Optional per-array shapes. A missing list or a None entry falls back
            to `spec["metadata"]["shape"]`.
        dtypes: Optional per-array dtypes. A missing list or a None entry falls back to
            `spec["metadata"]["dtype"]`.
        concurrent_gb: Accepted for signature compatibility; not used by this implementation.

    Returns:
        A list of restored arrays, in the same order as `tensorstore_specs`.

    Raises:
        ValueError: If the input sequences differ in length, a sharding is not a
            `jax.sharding.Sharding`, or a spec is not backed by the "gcs" kvstore driver.
    """
    self.wait_until_finished()

    num_arrays = len(tensorstore_specs)
    # The optional inputs default to None, which cannot be zipped; expand them so that every
    # entry falls back to the metadata stored in its tensorstore spec.
    if global_shapes is None:
        global_shapes = [None] * num_arrays
    if dtypes is None:
        dtypes = [None] * num_arrays

    lengths = {
        "shardings": len(shardings),
        "global_shapes": len(global_shapes),
        "dtypes": len(dtypes),
    }
    for arg_name, length in lengths.items():
        if length != num_arrays:
            raise ValueError(
                f"Expected {num_arrays} entries in `{arg_name}` to match `tensorstore_specs`,"
                f" got {length}."
            )

    # Maps (location, mesh) -> the per-array inputs of one bulk read.
    bulk_read_inputs: dict[tuple, dict[str, list]] = {}

    per_array_inputs = zip(tensorstore_specs, shardings, global_shapes, dtypes)
    for index, (tensorstore_spec, sharding, global_shape, dtype) in enumerate(per_array_inputs):
        sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
        if not isinstance(sharding, jax.sharding.Sharding):
            raise ValueError(
                "sharding passed to deserialization should be specified, concrete and"
                f" an instance of `jax.sharding.Sharding`. Got {sharding}"
            )

        kvstore = tensorstore_spec["kvstore"]
        if kvstore.get("driver", "") != "gcs":
            raise ValueError("Only GCS backed checkpoints have been tested")

        bucket = kvstore["bucket"]
        # The last path component names the array; the rest is the shared directory.
        path, name = os.path.split(kvstore["path"])
        location = f"gs://{bucket}/{path}"
        shape = tensorstore_spec["metadata"]["shape"] if global_shape is None else global_shape
        dtype = tensorstore_spec["metadata"]["dtype"] if dtype is None else dtype

        # NOTE(review): this requires shardings that expose `.mesh` (e.g. NamedSharding);
        # other sharding types will raise AttributeError here.
        key = (location, sharding.mesh)
        group = bulk_read_inputs.setdefault(
            key, dict(indices=[], names=[], dtypes=[], shapes=[], shardings=[])
        )
        group["indices"].append(index)
        group["names"].append(name)
        group["dtypes"].append(dtype)
        group["shapes"].append(shape)
        group["shardings"].append(sharding)

    # Results are placed by original index so the output order matches the input order
    # regardless of how entries were grouped.
    results = [None] * num_arrays
    for (location, mesh), bulk_read_input in bulk_read_inputs.items():
        arrays, read_future = helper.read_arrays(
            location=location,
            names=bulk_read_input["names"],
            dtypes=bulk_read_input["dtypes"],
            shapes=bulk_read_input["shapes"],
            shardings=bulk_read_input["shardings"],
            devices=mesh.devices,
            timeout=datetime.timedelta(minutes=5),
        )
        # Wait on this group's read before moving on to the next group.
        read_future.result()

        for index, restored in zip(bulk_read_input["indices"], arrays):
            results[index] = restored

    return results
| 640 | + |
581 | 641 | # Copied from (with modifications) |
582 | 642 | # https://github.com/jax-ml/jax/blob/66037d10e7742c4fcadd07f0459a00813ec7ed5f/jax/experimental/array_serialization/serialization.py#L413-L429 |
583 | | - def deserialize( |
| 643 | + def deserialize_original( |
584 | 644 | self, |
585 | 645 | shardings: Sequence[Union[jax.sharding.Sharding, Layout]], |
586 | 646 | tensorstore_specs: Sequence[dict[str, Any]], |
|
0 commit comments