diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index ea55f557a8..34bae678ee 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -874,6 +874,17 @@ class MegatronEngineConfig: # Checkpointing Configuration async_save: bool = False use_checkpoint_opt_param_scheduler: bool = True + distrib_optim_sharding_type: str = field( + default="dp_reshardable", + metadata={ + "help": ( + "Sharding type for distributed optimizer checkpoint. " + "'dp_reshardable' works with megatron-core >=0.11; set to " + "'fully_sharded_model_space' to load legacy checkpoints." + ), + "choices": ["dp_reshardable", "fully_sharded_model_space"], + }, + ) # Deterministic Option # NOTE: This option forces torch to use deterministic algorithms, diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index a512469bc0..57176265be 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1358,6 +1358,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: use_distributed_optimizer=use_distributed_optimizer, use_checkpoint_opt_param_scheduler=self.mcore_config.use_checkpoint_opt_param_scheduler, async_save=self.mcore_config.async_save, + distrib_optim_sharding_type=self.mcore_config.distrib_optim_sharding_type, ) def _check_rollout_engine_connected(self) -> None: diff --git a/areal/engine/megatron_utils/checkpointer.py b/areal/engine/megatron_utils/checkpointer.py index 78a32867f5..dc1e382ae8 100644 --- a/areal/engine/megatron_utils/checkpointer.py +++ b/areal/engine/megatron_utils/checkpointer.py @@ -147,6 +147,7 @@ def __init__( use_checkpoint_opt_param_scheduler: bool = False, use_dist_checkpointing: bool = True, async_save: bool = False, + distrib_optim_sharding_type: str = "dp_reshardable", ): self.model = model self.optimizer = optimizer @@ -160,6 +161,7 @@ def __init__( self.rank = torch.distributed.get_rank() self.use_dist_checkpointing = use_dist_checkpointing self.async_save = async_save + self.distrib_optim_sharding_type = distrib_optim_sharding_type if async_save: raise NotImplementedError("Async save not implenmented yet!") @@ -280,7 +282,17 @@ def generate_state_dict( # Optimizer State Dict if with_optimizer: torch.distributed.barrier() - optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict) + # megatron-core >=0.11 removed flattened_range support in + # ShardedTensor.validate_metadata_integrity(), but the default + # sharding type (fully_sharded_model_space) still sets + # flattened_range, causing save/load to fail. Use + # dp_reshardable which does not rely on flattened_range. + optimizer_sharded_states = self.optimizer.sharded_state_dict( + state_dict, + metadata={ + "distrib_optim_sharding_type": self.distrib_optim_sharding_type + }, + ) state_dict["optimizer"] = optimizer_sharded_states if self.lr_scheduler is not None: