From f1544d1dfd9395c3907e86af288a5dad3ea81a11 Mon Sep 17 00:00:00 2001 From: Marco Berlot Date: Wed, 18 Feb 2026 06:28:25 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 871839606 --- .../_src/handlers/async_checkpoint_handler.py | 60 ++++++++++++++++--- .../base_pytree_checkpoint_handler.py | 18 +++--- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/async_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/async_checkpoint_handler.py index a1f766db6..09d942f1d 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/async_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/async_checkpoint_handler.py @@ -20,27 +20,69 @@ from etils import epath from orbax.checkpoint._src.futures import future from orbax.checkpoint._src.handlers import checkpoint_handler +from orbax.checkpoint._src.path import types as path_types class AsyncCheckpointHandler(checkpoint_handler.CheckpointHandler): - """An interface providing async methods that can be used with CheckpointHandler.""" + """An interface providing async methods used with AsyncCheckpointer.""" @abc.abstractmethod async def async_save( - self, directory: epath.Path, *args, **kwargs + self, + directory: epath.Path, + *args, + **kwargs, ) -> Optional[List[future.Future]]: - """Constructs a save operation. + """Saves the given item to the provided directory. - Synchronously awaits a copy of the item, before returning commit futures - necessary to save the item. + Args: + directory: the directory to save to. + *args: additional arguments for save. + **kwargs: additional arguments for save. + + Returns: + A list of commit futures which can be awaited upon to complete the save + operation. + """ + pass + + +class DeferredPathAsyncCheckpointHandler(AsyncCheckpointHandler): + """Handler interface that receives Path or PathAwaitingCreation. - Note: Any operations on directory should be done by using - `future.CommitFutureAwaitingContractedSignals` to wait for directories to be - created. + This interface extends AsyncCheckpointHandler with an async_save method that + accepts either an epath.Path or PathAwaitingCreation, allowing handlers to + work with deferred paths (e.g., TFHub) where the actual path is allocated + asynchronously. + + Handlers implementing this interface can: + 1. Receive a deferred path representation before the path is allocated + 2. Wait for STEP_DIRECTORY_CREATION signal inside their CommitFuture + 3. Access the path via await_creation() or .path after the signal + """ + + @abc.abstractmethod + async def async_save( + self, + directory: epath.Path | path_types.PathAwaitingCreation, + *args, + **kwargs, + ) -> Optional[List[future.Future]]: + """Constructs a save operation with support for deferred paths. + + This method accepts an epath.Path or PathAwaitingCreation. + When a deferred path is passed, handler coroutines should wait for the + STEP_DIRECTORY_CREATION signal before accessing the path. Args: - directory: the directory to save to. + directory: The directory to save to. May be an epath.Path or + PathAwaitingCreation. For deferred paths, await_creation() or signal + ordering ensures the path is available. *args: additional arguments for save. **kwargs: additional arguments for save. + + Returns: + A list of futures that will commit the data when awaited. """ + pass diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index 936deaa8c..b5e5b8457 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -48,6 +48,7 @@ from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import async_path from orbax.checkpoint._src.path import format_utils +from orbax.checkpoint._src.path import types as path_types from orbax.checkpoint._src.serialization import limits from orbax.checkpoint._src.serialization import ocdbt_utils from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils @@ -308,7 +309,7 @@ def _format_bytes(bytes_value: Optional[int]) -> str: class BasePyTreeCheckpointHandler( - async_checkpoint_handler.AsyncCheckpointHandler + async_checkpoint_handler.DeferredPathAsyncCheckpointHandler ): """A CheckpointHandler implementation for any PyTree structure. @@ -433,7 +434,7 @@ def get_param_names(self, item: PyTree) -> PyTree: def _get_param_infos( self, item: PyTree, - directory: epath.Path, + directory: epath.Path | path_types.PathAwaitingCreation, *, use_ocdbt: bool = True, use_compression: bool | None = True, @@ -581,7 +582,7 @@ def _handle_diffs(keypath, diff): async def async_save( self, - directory: epath.Path, + directory: epath.Path | path_types.PathAwaitingCreation, args: BasePyTreeSaveArgs, ) -> Optional[List[future.Future]]: """Saves a PyTree to a given directory. @@ -644,7 +645,7 @@ async def async_save( use_zarr3=self._use_zarr3, ) assert all( - leaf.parent_dir == directory for leaf in jax.tree.leaves(param_infos) + leaf.parent_dir is directory for leaf in jax.tree.leaves(param_infos) ) serialize_ops = [] # List of (coros -> List of futures) @@ -660,11 +661,11 @@ async def async_save( # Cannot rely solely on the metadata file existing pre-empted saves may be # misclassified as partial saves. partial_save = ( - await async_path.exists(directory / PYTREE_METADATA_FILE) + isinstance(directory, epath.Path) + and await async_path.exists(directory / PYTREE_METADATA_FILE) # TODO: b/428711337 - Use method from v1/_src/partial/path.py instead. and '.partial_save' in directory.parent.name ) - batch_requests_ready_time = time.time() if partial_save: serialize_ops, tree_memory_size, param_infos, save_args = ( @@ -1190,7 +1191,7 @@ async def _write_metadata_file( async def _write_metadata_after_commits( self, commit_futures: List[future.Future], - checkpoint_dir: epath.Path, + checkpoint_dir: path_types.PathAwaitingCreation | epath.Path, *, param_infos: PyTree, save_args: PyTree, @@ -1205,6 +1206,9 @@ async def _write_metadata_after_commits( for commit_future in commit_futures: await asyncio.to_thread(commit_future.result) + if isinstance(checkpoint_dir, path_types.PathAwaitingCreation): + checkpoint_dir = await checkpoint_dir.await_creation() + commit_time = time.time() # `write_shape` is extracted from ArrayMetadata store saved during # materialization of commit_futures. Then it is written to the pytree