diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py index 49ed335f..2affecc5 100644 --- a/init2winit/checkpoint.py +++ b/init2winit/checkpoint.py @@ -18,13 +18,15 @@ This is useful for training neural networks with stax, where model parameters are nested numpy arrays. """ + from absl import flags from absl import logging +from init2winit.dataset_lib import data_utils import jax -# pylint: disable=g-importing-member -from jax.experimental.multihost_utils import process_allgather +from jax.experimental import multihost_utils import orbax.checkpoint as ocp + FLAGS = flags.FLAGS @@ -49,7 +51,8 @@ def maybe_restore_checkpoint( unreplicated_batch_stats, unreplicated_training_metrics_state, orbax_checkpoint_manager=None, - orbax_checkpoint_manager_external=None): + orbax_checkpoint_manager_external=None, +): """Optionally restores from a checkpoint. The checkpoint logic is as follows: if `orbax_checkpoint_manager` contains @@ -77,9 +80,16 @@ def maybe_restore_checkpoint( in train_dir. """ uninitialized_global_step = -1 + # Unwrap CpuOffloaded leaves before passing to Orbax — it only accepts + # numpy/jax arrays as target leaves. The training algorithm's + # restore_optimizer_state() hook re-wraps them after restore. + unwrapped_optimizer_state = jax.tree.map( + lambda x: x.array if isinstance(x, data_utils.CpuOffloaded) else x, + unreplicated_optimizer_state, + ) unreplicated_checkpoint_state = dict( params=unreplicated_params, - optimizer_state=unreplicated_optimizer_state, + optimizer_state=unwrapped_optimizer_state, batch_stats=unreplicated_batch_stats, training_metrics_grabber=unreplicated_training_metrics_state, global_step=uninitialized_global_step, @@ -96,7 +106,7 @@ def maybe_restore_checkpoint( # train_dir does not exist or if it exists and contains no checkpoints. # Note that we could likely change the below line to: # found_checkpoint = latest_ckpt != unreplicated_checkpoint_state - found_checkpoint = (latest_ckpt['global_step'] != uninitialized_global_step) + found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step # If there's a latest checkpoint in the train_dir, restore from that. if found_checkpoint: @@ -123,7 +133,8 @@ def maybe_restore_checkpoint( 0, # global_step 0, # sum_train_cost 0, # preemption_count - False) # is_restored + False, + ) # is_restored else: # Else, don't restore from any checkpoint. return ( unreplicated_optimizer_state, @@ -133,7 +144,8 @@ def maybe_restore_checkpoint( 0, # global_step 0, # sum_train_cost 0, # preemption_count - False) # is_restored + False, + ) # is_restored return ( ckpt_to_return['optimizer_state'], @@ -143,7 +155,8 @@ def maybe_restore_checkpoint( ckpt_to_return['global_step'], # global_step ckpt_to_return['sum_train_cost'], ckpt_to_return['preemption_count'], # preemption_count - is_restored) # is_restored + is_restored, + ) # is_restored def unreplicate_and_save_checkpoint( @@ -154,38 +167,56 @@ def unreplicate_and_save_checkpoint( global_step, preemption_count, sum_train_cost, - orbax_checkpoint_manager): + orbax_checkpoint_manager, +): """Saves pytree, step, preemption_count, and sum_train_cost to train_dir.""" logging.info('Saving checkpoint to ckpt_%d', global_step) # jax.device_get doesn't work if jax.Array lives on multiple hosts. # So we first all_gather it to the host and then call jax.device_get if jax.process_count() > 1: - unreplicated_optimizer_state = jax.device_get( - process_allgather(optimizer_state, tiled=True)) - unreplicated_params = jax.device_get(process_allgather(params, tiled=True)) + unreplicated_optimizer_state = jax.tree.map( + lambda x: x + if isinstance(x, data_utils.CpuOffloaded) + else jax.device_get(multihost_utils.process_allgather(x, tiled=True)), + optimizer_state, + ) + unreplicated_params = jax.device_get( + multihost_utils.process_allgather(params, tiled=True) + ) else: - unreplicated_optimizer_state = jax.device_get(optimizer_state) + unreplicated_optimizer_state = jax.tree.map( + lambda x: x + if isinstance(x, data_utils.CpuOffloaded) + else jax.device_get(x), + optimizer_state, + ) unreplicated_params = jax.device_get(params) unreplicated_batch_stats = jax.device_get(batch_stats) - unreplicated_training_metrics_state = jax.device_get( - training_metrics_state) + unreplicated_training_metrics_state = jax.device_get(training_metrics_state) unreplicated_sum_train_cost = jax.device_get(sum_train_cost) - state = dict(global_step=global_step, - preemption_count=preemption_count, - sum_train_cost=unreplicated_sum_train_cost, - optimizer_state=unreplicated_optimizer_state, - params=unreplicated_params, - batch_stats=unreplicated_batch_stats, - training_metrics_grabber=unreplicated_training_metrics_state) - save_checkpoint(global_step, - state, - orbax_checkpoint_manager=orbax_checkpoint_manager) + # Unwrap CpuOffloaded leaves to plain numpy arrays for Orbax serialization. + # CpuOffloaded is a runtime-only wrapper for sharding control; on disk the + # wrapped arrays are stored as regular numpy arrays. + unreplicated_optimizer_state = jax.tree.map( + lambda x: x.array if isinstance(x, data_utils.CpuOffloaded) else x, + unreplicated_optimizer_state, + ) + state = dict( + global_step=global_step, + preemption_count=preemption_count, + sum_train_cost=unreplicated_sum_train_cost, + optimizer_state=unreplicated_optimizer_state, + params=unreplicated_params, + batch_stats=unreplicated_batch_stats, + training_metrics_grabber=unreplicated_training_metrics_state, + ) + save_checkpoint( + global_step, state, orbax_checkpoint_manager=orbax_checkpoint_manager + ) logging.info('Done saving checkpoint.') -def save_checkpoint(step, - state, - orbax_checkpoint_manager): +def save_checkpoint(step, state, orbax_checkpoint_manager): """Saves checkpoint to train_dir. A list of checkpoints will be stored in train_dir/step. @@ -229,9 +260,10 @@ def load_latest_checkpoint(target=None, orbax_checkpoint_manager=None): """Loads the most recent checkpoint listed in train_dir. Args: - target: used for checkpointing, a pytree whose structure will be used - to structure the restored checkpoint data. + target: used for checkpointing, a pytree whose structure will be used to + structure the restored checkpoint data. orbax_checkpoint_manager: An orbax.CheckpointManager instance. + Returns: The state restored from the checkpoint. If using Flax checkpointing and target=None, this will return a unstructured dictionary containing the diff --git a/init2winit/dataset_lib/data_utils.py b/init2winit/dataset_lib/data_utils.py index 5f0ec34d..26882eb0 100644 --- a/init2winit/dataset_lib/data_utils.py +++ b/init2winit/dataset_lib/data_utils.py @@ -30,13 +30,15 @@ import jraph import numpy as np - -Dataset = collections.namedtuple('Dataset', [ - 'train_iterator_fn', - 'eval_train_epoch', - 'valid_epoch', - 'test_epoch', -]) +Dataset = collections.namedtuple( + 'Dataset', + [ + 'train_iterator_fn', + 'eval_train_epoch', + 'valid_epoch', + 'test_epoch', + ], +) def log_rss(msg: str): @@ -45,8 +47,9 @@ def log_rss(msg: str): logging.info('%s — RSS: %.1f MB', msg, rss_mb) -def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike], - num_prefetch: int) -> Iterator[jax.typing.ArrayLike]: +def prefetch_iterator( + source_iter: Iterator[jax.typing.ArrayLike], num_prefetch: int +) -> Iterator[jax.typing.ArrayLike]: """Wraps the given iterator with prefetching. Args: @@ -121,14 +124,16 @@ def iterator_as_numpy(iterator): yield jax.tree.map(lambda y: y._numpy(), x) # pylint: disable=protected-access -def image_iterator(data, - rescale, - output_shape, - is_one_hot, - autoencoder, - shuffle_rng=None, - augment_fn=None, - include_example_keys=False): +def image_iterator( + data, + rescale, + output_shape, + is_one_hot, + autoencoder, + shuffle_rng=None, + augment_fn=None, + include_example_keys=False, +): """Preprocesses the batch data arrays in the data generator. Rescales inputs. One hot encode targets if is_one_hot is true. @@ -166,11 +171,13 @@ def image_iterator(data, yield {'inputs': inputs, 'targets': targets} -def maybe_pad_batch(batch, - desired_batch_size, - data_format=None, - mask_key=None, - padding_value=0.0): +def maybe_pad_batch( + batch, + desired_batch_size, + data_format=None, + mask_key=None, + padding_value=0.0, +): """Zero pad the batch on the right to desired_batch_size. All keys in the batch dictionary will have their corresponding arrays padded. @@ -187,9 +194,9 @@ def maybe_pad_batch(batch, dimension to pad. If not provided then it is assumed the first dimension is the batch dimension. mask_key: Typically used for text datasets, it's either 'inputs' (for - encoder only models like language models) or 'targets' - (for encoder-decoder models like seq2seq tasks) to decide weights for - padded sequence. For Image datasets, this will be (most likely) unused. + encoder only models like language models) or 'targets' (for + encoder-decoder models like seq2seq tasks) to decide weights for padded + sequence. For Image datasets, this will be (most likely) unused. padding_value: value to be used as padding. Returns: @@ -247,7 +254,7 @@ def make_global_array(local_data, mesh): """Util to combine per-host batches into a global batch array. Args: - local_data: local data batch on host. + local_data: local data batch on host. mesh: mesh specification to shard the data. Returns: @@ -265,13 +272,46 @@ def make_global_array(local_data, mesh): return global_array +class CpuOffloaded: + """Marker wrapper for arrays that should remain on CPU. + + Wraps a numpy array to signal to the trainer's sharding and checkpoint + code that this leaf should be skipped during JAX sharding operations + and device transfers. Used by optimizers that offload state to host + memory (e.g., single-worker DiLoCo's slow_params and nesterov_b). + + The wrapped array is accessible via the `array` attribute. + """ + + def __init__(self, array): + self.array = array + + @property + def shape(self): + return self.array.shape + + @property + def dtype(self): + return self.array.dtype + + def __repr__(self): + return f'CpuOffloaded(shape={self.shape}, dtype={self.dtype})' + + def shard_pytree(pytree, mesh, shardings=None): + """Shards a pytree with the given shardings and mesh.""" + if shardings is None: shardings = nn.get_sharding(pytree, mesh) + + def _maybe_shard(arr, sharding): + """Shards the given array if the sharding is not None.""" + if sharding is None: + return arr + return jax.make_array_from_process_local_data(sharding, arr, arr.shape) + pytree = jax.tree_util.tree_map( - lambda arr, sharding: jax.make_array_from_process_local_data( - sharding, arr, arr.shape - ), + _maybe_shard, pytree, shardings, ) diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index 32d3d7a3..7dbb4132 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -653,6 +653,13 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng): logging.info( 'Checkpoint restored in %f seconds', time.time() - start_time ) + # Allow training algorithms to post-process restored optimizer state + # (e.g., re-wrap CpuOffloaded leaves stripped during serialization). + unreplicated_optimizer_state = ( + self.training_algorithm.restore_optimizer_state( + unreplicated_optimizer_state + ) + ) start_time = time.time() ( self._params, diff --git a/init2winit/trainer_lib/trainer.py b/init2winit/trainer_lib/trainer.py index c35c39b1..272b4fa4 100644 --- a/init2winit/trainer_lib/trainer.py +++ b/init2winit/trainer_lib/trainer.py @@ -105,10 +105,18 @@ def shard( _, params = data_utils.shard_pytree( unreplicated_params, self._mesh, params_sharding ) + + def _get_sharding(x): + """Returns the sharding for the given pytree node.""" + if isinstance(x, data_utils.CpuOffloaded): + return None + elif isinstance(x, jax.Array) and isinstance(x.sharding, NamedSharding): + return x.sharding + else: + return NamedSharding(self._mesh, P()) + optimizer_state_sharding = jax.tree_util.tree_map( - lambda x: x.sharding - if isinstance(x.sharding, NamedSharding) - else NamedSharding(self._mesh, P()), + _get_sharding, unreplicated_optimizer_state, ) diff --git a/init2winit/trainer_lib/training_algorithm.py b/init2winit/trainer_lib/training_algorithm.py index 4f07383b..cda43631 100644 --- a/init2winit/trainer_lib/training_algorithm.py +++ b/init2winit/trainer_lib/training_algorithm.py @@ -197,6 +197,21 @@ def init_optimizer_state( Optimizer state: Pytree of optimizer state. """ + def restore_optimizer_state(self, optimizer_state): + """Post-processes optimizer state after checkpoint restore. + + Override this method in subclasses that use runtime wrappers (e.g., + CpuOffloaded) which are stripped during serialization and need to be + re-applied after deserialization. + + Args: + optimizer_state: The restored optimizer state pytree (plain numpy arrays). + + Returns: + The post-processed optimizer state, ready for sharding. + """ + return optimizer_state + # Per-optimizer default opt_hparams for OptaxTrainingAlgorithm. # These consolidate all the inline defaults from get_optimizer() in