Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,69 @@
from etils import epath
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.path import types as path_types


class AsyncCheckpointHandler(checkpoint_handler.CheckpointHandler):
"""An interface providing async methods that can be used with CheckpointHandler."""
"""An interface providing async methods used with AsyncCheckpointer."""

@abc.abstractmethod
async def async_save(
self, directory: epath.Path, *args, **kwargs
self,
directory: epath.Path,
*args,
**kwargs,
) -> Optional[List[future.Future]]:
"""Constructs a save operation.
"""Saves the given item to the provided directory.

Synchronously awaits a copy of the item, before returning commit futures
necessary to save the item.
Args:
directory: the directory to save to.
*args: additional arguments for save.
**kwargs: additional arguments for save.

Returns:
A list of commit futures which can be awaited upon to complete the save
operation.
"""
pass


class DeferredPathAsyncCheckpointHandler(AsyncCheckpointHandler):
"""Handler interface that receives Path or PathAwaitingCreation.

Note: Any operations on directory should be done by using
`future.CommitFutureAwaitingContractedSignals` to wait for directories to be
created.
This interface extends AsyncCheckpointHandler with an async_save method that
accepts either an epath.Path or PathAwaitingCreation, allowing handlers to
work with deferred paths (e.g., TFHub) where the actual path is allocated
asynchronously.

Handlers implementing this interface can:
1. Receive a deferred path representation before the path is allocated
2. Wait for STEP_DIRECTORY_CREATION signal inside their CommitFuture
3. Access the path via await_creation() or .path after the signal
"""

@abc.abstractmethod
async def async_save(
self,
directory: epath.Path | path_types.PathAwaitingCreation,
*args,
**kwargs,
) -> Optional[List[future.Future]]:
"""Constructs a save operation with support for deferred paths.

This method accepts an epath.Path or PathAwaitingCreation.
When a deferred path is passed, handler coroutines should wait for the
STEP_DIRECTORY_CREATION signal before accessing the path.

Args:
directory: the directory to save to.
directory: The directory to save to. May be an epath.Path or
PathAwaitingCreation. For deferred paths, await_creation() or signal
ordering ensures the path is available.
*args: additional arguments for save.
**kwargs: additional arguments for save.

Returns:
A list of futures that will commit the data when awaited.
"""

pass
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import format_utils
from orbax.checkpoint._src.path import types as path_types
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import ocdbt_utils
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
Expand Down Expand Up @@ -308,7 +309,7 @@ def _format_bytes(bytes_value: Optional[int]) -> str:


class BasePyTreeCheckpointHandler(
async_checkpoint_handler.AsyncCheckpointHandler
async_checkpoint_handler.DeferredPathAsyncCheckpointHandler
):
"""A CheckpointHandler implementation for any PyTree structure.

Expand Down Expand Up @@ -433,7 +434,7 @@ def get_param_names(self, item: PyTree) -> PyTree:
def _get_param_infos(
self,
item: PyTree,
directory: epath.Path,
directory: epath.Path | path_types.PathAwaitingCreation,
*,
use_ocdbt: bool = True,
use_compression: bool | None = True,
Expand Down Expand Up @@ -581,7 +582,7 @@ def _handle_diffs(keypath, diff):

async def async_save(
self,
directory: epath.Path,
directory: epath.Path | path_types.PathAwaitingCreation,
args: BasePyTreeSaveArgs,
) -> Optional[List[future.Future]]:
"""Saves a PyTree to a given directory.
Expand Down Expand Up @@ -644,7 +645,7 @@ async def async_save(
use_zarr3=self._use_zarr3,
)
assert all(
leaf.parent_dir == directory for leaf in jax.tree.leaves(param_infos)
leaf.parent_dir is directory for leaf in jax.tree.leaves(param_infos)
)

serialize_ops = [] # List of (coros -> List of futures)
Expand All @@ -660,11 +661,11 @@ async def async_save(
# Cannot rely solely on the metadata file existing pre-empted saves may be
# misclassified as partial saves.
partial_save = (
await async_path.exists(directory / PYTREE_METADATA_FILE)
isinstance(directory, epath.Path)
and await async_path.exists(directory / PYTREE_METADATA_FILE)
# TODO: b/428711337 - Use method from v1/_src/partial/path.py instead.
and '.partial_save' in directory.parent.name
)

batch_requests_ready_time = time.time()
if partial_save:
serialize_ops, tree_memory_size, param_infos, save_args = (
Expand Down Expand Up @@ -1190,7 +1191,7 @@ async def _write_metadata_file(
async def _write_metadata_after_commits(
self,
commit_futures: List[future.Future],
checkpoint_dir: epath.Path,
checkpoint_dir: path_types.PathAwaitingCreation | epath.Path,
*,
param_infos: PyTree,
save_args: PyTree,
Expand All @@ -1205,6 +1206,9 @@ async def _write_metadata_after_commits(
for commit_future in commit_futures:
await asyncio.to_thread(commit_future.result)

if isinstance(checkpoint_dir, path_types.PathAwaitingCreation):
checkpoint_dir = await checkpoint_dir.await_creation()

commit_time = time.time()
# `write_shape` is extracted from ArrayMetadata store saved during
# materialization of commit_futures. Then it is written to the pytree
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/_src/path/atomicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import types as path_types
from orbax.checkpoint._src.path import utils
from orbax.checkpoint._src.path.snapshot import snapshot as snapshot_lib
from orbax.checkpoint.experimental.v1._src.path import types as path_types


ValidationError = atomicity_types.ValidationError
Expand Down
69 changes: 69 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Types for path-related constructs."""

from __future__ import annotations

import typing
from typing import Protocol

from etils import epath


Path = epath.Path
PathLike = Path | str


@typing.runtime_checkable
class PathAwaitingCreation(Protocol):
"""A path that may not exist yet, but will exist after `await_creation`.

This construct is used to represent a path in the process of being created.
The underlying path can be accessed logically, but the actual location in
the filesystem should not be accessed until :py:meth:`.await_creation` is
called.

Usage::

path: :py:class:`.PathAwaitingCreation` = ...
# Logical accesses are OK.
print(path.path)
# Block until the path is known to exist.
path = await path.await_creation()
path.exists() # True.
"""

def __truediv__(self, other: PathLike) -> PathAwaitingCreation:
...

@property
def path(self) -> Path:
...

async def await_creation(self) -> Path:
"""Waits for the directory to be created.

This is a blocking operation, though it should return immediately if the
path has already been created. Be cautious about where this method is
called, since implementations may trigger the creation if `await_creation`
is called before the directory creation has been triggered.

It is recommended to only call this from background awaitables, or to
delay as long as possible.

Returns:
The path that was created.
"""
...
58 changes: 6 additions & 52 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/path/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,10 @@

"""Types for path-related constructs."""

from __future__ import annotations
# pylint: disable=g-importing-member, g-multiple-import, unused-import, g-bad-import-order

import typing
from typing import Protocol

from etils import epath


Path = epath.Path
PathLike = Path | str


@typing.runtime_checkable
class PathAwaitingCreation(Protocol):
"""A path that may not exist yet, but will exist after `await_creation`.

This construct is used to represent a path in the process of being created.
The underlying path can be accessed logically, but the actual location in
the filesystem should not be accessed until :py:meth:`.await_creation` is
called.

Usage::

path: :py:class:`.PathAwaitingCreation` = ...
# Logical accesses are OK.
print(path.path)
# Block until the path is known to exist.
path = await path.await_creation()
path.exists() # True.
"""

def __truediv__(self, other: PathLike) -> PathAwaitingCreation:
...

@property
def path(self) -> Path:
...

async def await_creation(self) -> Path:
"""Waits for the directory to be created.

This is a blocking operation, though it should return immediately if the
path has already been created. Be cautious about where this method is
called, since implementations may trigger the creation if `await_creation`
is called before the directory creation has been triggered.

It is recommended to only call this from background awaitables, or to
delay as long as possible.

Returns:
The path that was created.
"""
...
from orbax.checkpoint._src.path.types import (
Path,
PathLike,
PathAwaitingCreation,
)
Loading