From 93e21fb4ff89597fa4724b7c9469b7a2f021ddb2 Mon Sep 17 00:00:00 2001 From: youyi Date: Fri, 15 May 2026 15:06:24 +0800 Subject: [PATCH 1/2] fix(checkpointer): use dp_reshardable sharding type for megatron-core >=0.11 megatron-core >=0.11 removed flattened_range support in ShardedTensor.validate_metadata_integrity(), but the default sharding type (fully_sharded_model_space) still sets flattened_range, causing save/load to fail. Switch to dp_reshardable which does not rely on flattened_range. Co-Authored-By: Claude Opus 4.6 (1M context) --- areal/engine/megatron_utils/checkpointer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_utils/checkpointer.py b/areal/engine/megatron_utils/checkpointer.py index 78a32867f5..383bdd877d 100644 --- a/areal/engine/megatron_utils/checkpointer.py +++ b/areal/engine/megatron_utils/checkpointer.py @@ -280,7 +280,15 @@ def generate_state_dict( # Optimizer State Dict if with_optimizer: torch.distributed.barrier() - optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict) + # megatron-core >=0.11 removed flattened_range support in + # ShardedTensor.validate_metadata_integrity(), but the default + # sharding type (fully_sharded_model_space) still sets + # flattened_range, causing save/load to fail. Use + # dp_reshardable which does not rely on flattened_range. + optimizer_sharded_states = self.optimizer.sharded_state_dict( + state_dict, + metadata={"distrib_optim_sharding_type": "dp_reshardable"}, + ) state_dict["optimizer"] = optimizer_sharded_states if self.lr_scheduler is not None: From cf37fdd96fe56e2a61d2c14f2590febb6aa79fec Mon Sep 17 00:00:00 2001 From: youyi Date: Mon, 18 May 2026 19:41:32 +0800 Subject: [PATCH 2/2] refactor(engine): expose distrib_optim_sharding_type as config Move the hard-coded 'dp_reshardable' sharding type into MegatronEngineConfig so users with legacy checkpoints saved under 'fully_sharded_model_space' can load them by flipping the config instead of patching the source. Key changes: - Add distrib_optim_sharding_type field to MegatronEngineConfig (default 'dp_reshardable') - Plumb the value through MegatronCheckpointManager and use it in generate_state_dict instead of the hard-coded string --- areal/api/cli_args.py | 11 +++++++++++ areal/engine/megatron_engine.py | 1 + areal/engine/megatron_utils/checkpointer.py | 6 +++++- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index ea55f557a8..34bae678ee 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -874,6 +874,17 @@ class MegatronEngineConfig: # Checkpointing Configuration async_save: bool = False use_checkpoint_opt_param_scheduler: bool = True + distrib_optim_sharding_type: str = field( + default="dp_reshardable", + metadata={ + "help": ( + "Sharding type for distributed optimizer checkpoint. " + "'dp_reshardable' works with megatron-core >=0.11; set to " + "'fully_sharded_model_space' to load legacy checkpoints." + ), + "choices": ["dp_reshardable", "fully_sharded_model_space"], + }, + ) # Deterministic Option # NOTE: This option forces torch to use deterministic algorithms, diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index a512469bc0..57176265be 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1358,6 +1358,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None: use_distributed_optimizer=use_distributed_optimizer, use_checkpoint_opt_param_scheduler=self.mcore_config.use_checkpoint_opt_param_scheduler, async_save=self.mcore_config.async_save, + distrib_optim_sharding_type=self.mcore_config.distrib_optim_sharding_type, ) def _check_rollout_engine_connected(self) -> None: diff --git a/areal/engine/megatron_utils/checkpointer.py b/areal/engine/megatron_utils/checkpointer.py index 383bdd877d..dc1e382ae8 100644 --- a/areal/engine/megatron_utils/checkpointer.py +++ b/areal/engine/megatron_utils/checkpointer.py @@ -147,6 +147,7 @@ def __init__( use_checkpoint_opt_param_scheduler: bool = False, use_dist_checkpointing: bool = True, async_save: bool = False, + distrib_optim_sharding_type: str = "dp_reshardable", ): self.model = model self.optimizer = optimizer @@ -160,6 +161,7 @@ def __init__( self.rank = torch.distributed.get_rank() self.use_dist_checkpointing = use_dist_checkpointing self.async_save = async_save + self.distrib_optim_sharding_type = distrib_optim_sharding_type if async_save: raise NotImplementedError("Async save not implenmented yet!") @@ -287,7 +289,9 @@ def generate_state_dict( # dp_reshardable which does not rely on flattened_range. optimizer_sharded_states = self.optimizer.sharded_state_dict( state_dict, - metadata={"distrib_optim_sharding_type": "dp_reshardable"}, + metadata={ + "distrib_optim_sharding_type": self.distrib_optim_sharding_type + }, ) state_dict["optimizer"] = optimizer_sharded_states