Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,17 @@ class MegatronEngineConfig:
# Checkpointing Configuration
async_save: bool = False
use_checkpoint_opt_param_scheduler: bool = True
distrib_optim_sharding_type: str = field(
default="dp_reshardable",
metadata={
"help": (
"Sharding type for distributed optimizer checkpoint. "
"'dp_reshardable' works with megatron-core >=0.11; set to "
"'fully_sharded_model_space' to load legacy checkpoints."
),
"choices": ["dp_reshardable", "fully_sharded_model_space"],
},
)

# Deterministic Option
# NOTE: This option forces torch to use deterministic algorithms,
Expand Down
1 change: 1 addition & 0 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,7 @@ def _create_optimizer(self, ft_spec: FinetuneSpec) -> None:
use_distributed_optimizer=use_distributed_optimizer,
use_checkpoint_opt_param_scheduler=self.mcore_config.use_checkpoint_opt_param_scheduler,
async_save=self.mcore_config.async_save,
distrib_optim_sharding_type=self.mcore_config.distrib_optim_sharding_type,
)

def _check_rollout_engine_connected(self) -> None:
Expand Down
14 changes: 13 additions & 1 deletion areal/engine/megatron_utils/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(
use_checkpoint_opt_param_scheduler: bool = False,
use_dist_checkpointing: bool = True,
async_save: bool = False,
distrib_optim_sharding_type: str = "dp_reshardable",
):
self.model = model
self.optimizer = optimizer
Expand All @@ -160,6 +161,7 @@ def __init__(
self.rank = torch.distributed.get_rank()
self.use_dist_checkpointing = use_dist_checkpointing
self.async_save = async_save
self.distrib_optim_sharding_type = distrib_optim_sharding_type
if async_save:
raise NotImplementedError("Async save not implenmented yet!")

Expand Down Expand Up @@ -280,7 +282,17 @@ def generate_state_dict(
# Optimizer State Dict
if with_optimizer:
torch.distributed.barrier()
optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict)
# megatron-core >=0.11 removed flattened_range support in
# ShardedTensor.validate_metadata_integrity(), but the default
# sharding type (fully_sharded_model_space) still sets
# flattened_range, causing save/load to fail. Pass the configured
# sharding type explicitly; its default, dp_reshardable, does not
# rely on flattened_range.
optimizer_sharded_states = self.optimizer.sharded_state_dict(
state_dict,
metadata={
"distrib_optim_sharding_type": self.distrib_optim_sharding_type
},
)
state_dict["optimizer"] = optimizer_sharded_states

if self.lr_scheduler is not None:
Expand Down