diff --git a/areal/engine/core/model.py b/areal/engine/core/model.py index 6cab72403b..1000409418 100644 --- a/areal/engine/core/model.py +++ b/areal/engine/core/model.py @@ -59,6 +59,10 @@ def is_gemma3_model(model_type: str) -> bool: return model_type in ["gemma3"] +def is_qwen_model(model_type: str) -> bool: + return model_type.startswith("qwen") + + VALID_MOE_MODELS = [ "qwen3_moe", "qwen3_vl_moe", diff --git a/areal/engine/fsdp_utils/parallel.py b/areal/engine/fsdp_utils/parallel.py index b5aabf57d3..2abb82d1b7 100644 --- a/areal/engine/fsdp_utils/parallel.py +++ b/areal/engine/fsdp_utils/parallel.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +from typing import cast import torch from torch import nn from torch.distributed import ProcessGroup from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy -from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor import DTensor, Replicate, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, ParallelStyle, @@ -24,6 +25,7 @@ is_gemma3_model, is_moe_model, is_qwen3_vl_model, + is_qwen_model, is_valid_vision_model, ) from areal.engine.fsdp_utils import apply_fsdp2 @@ -216,6 +218,17 @@ def seq_len_divisor(self) -> int: return self._ps.tp_size * self._ps.cp_size +def _localize_dtensor_output( + _module: nn.Module, _inputs: tuple[object, ...], output: object +) -> object: + if not isinstance(output, DTensor): + return output + + dtensor_output = cast(DTensor, output) + placements = tuple(Replicate() for _ in dtensor_output.placements) + return dtensor_output.redistribute(placements=placements).to_local() + + def apply_non_moe_tp( model: nn.Module, model_config: PretrainedConfig, @@ -302,27 +315,33 @@ def apply_non_moe_tp( } ) + # Qwen models: norm→lm_head path has DTensor-incompatible ops, so we localize + # norm output to Replicate; lm_head input layout must match accordingly + use_local_final_norm_output = is_qwen_model(model_config.model_type) + head_input_layout = Replicate() if use_local_final_norm_output else Shard(1) + # For root module root_tp_plan: dict[str, ParallelStyle] = {} if hasattr(model, "lm_head") and isinstance(model.lm_head, nn.Module): - # Implicitly all-gather in ColwiseParallel - # Output is sharded on the last dimension (Shard(2)) + # Implicitly all-gather in ColwiseParallel when the input is Shard(1). + # Output is sharded on the last dimension (Shard(2)). root_tp_plan["lm_head"] = ColwiseParallel( - input_layouts=Shard(1), + input_layouts=head_input_layout, ) if hasattr(model, "score") and isinstance(model.score, nn.Module): # For PPO's critic model's score layer: - # 1. The input is sharded by sequence parallelism (Shard(1)) + # 1. The input follows the final norm output layout # 2. `score` is a linear layer with replicated weights # 3. All-gather the output along the sequence dimension to get the full results root_tp_plan["score"] = ReplicateParallel( - input_layout=Shard(1), - desired_input_layout=Shard(1), + input_layout=head_input_layout, + desired_input_layout=head_input_layout, output_layout=Replicate(), ) if is_valid_vision_model(model_config.model_type): if isinstance(model.model.language_model, nn.Module): + backbone = model.model.language_model # For vision-language models, avoid sharding the embedding layer because # the visual components access it without tensor parallelism support. # Instead, configure the first transformer layer to handle input @@ -342,10 +361,10 @@ def apply_non_moe_tp( patch_qwen3_vl_deepstack_process_for_tp, ) - patch_qwen3_vl_deepstack_process_for_tp(model.model.language_model) + patch_qwen3_vl_deepstack_process_for_tp(backbone) parallelize_module( - model.model.language_model, + backbone, device_mesh=tp_device_mesh, parallelize_plan=model_tp_plan, ) @@ -354,8 +373,9 @@ def apply_non_moe_tp( "Vision model does not have the required submodule 'model.language_model'" ) else: + backbone = model.model parallelize_module( - model.model, + backbone, device_mesh=tp_device_mesh, parallelize_plan=model_tp_plan, ) @@ -366,6 +386,18 @@ def apply_non_moe_tp( parallelize_plan=root_tp_plan, ) + # Register norm localization hook after parallelize_module so norm is already + # SequenceParallel. We reuse the `backbone` determined above to avoid + # re-deriving it. + if use_local_final_norm_output: + norm = getattr(backbone, "norm", None) + if not isinstance(norm, nn.Module): + raise RuntimeError( + f"Model backbone ({type(backbone).__name__}) does not have 'norm' submodule, " + f"but localized norm output is required for model_type={model_config.model_type!r}." + ) + cast(nn.Module, norm).register_forward_hook(_localize_dtensor_output) + def parallelize_model( model: nn.Module,