Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions areal/engine/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def is_gemma3_model(model_type: str) -> bool:
return model_type in ["gemma3"]


def is_qwen_model(model_type: str) -> bool:
return model_type.startswith("qwen")


VALID_MOE_MODELS = [
"qwen3_moe",
"qwen3_vl_moe",
Expand Down
52 changes: 42 additions & 10 deletions areal/engine/fsdp_utils/parallel.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import cast

import torch
from torch import nn
from torch.distributed import ProcessGroup
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
ParallelStyle,
Expand All @@ -24,6 +25,7 @@
is_gemma3_model,
is_moe_model,
is_qwen3_vl_model,
is_qwen_model,
is_valid_vision_model,
)
from areal.engine.fsdp_utils import apply_fsdp2
Expand Down Expand Up @@ -216,6 +218,17 @@ def seq_len_divisor(self) -> int:
return self._ps.tp_size * self._ps.cp_size


def _localize_dtensor_output(
_module: nn.Module, _inputs: tuple[object, ...], output: object
) -> object:
if not isinstance(output, DTensor):
return output

dtensor_output = cast(DTensor, output)
placements = tuple(Replicate() for _ in dtensor_output.placements)
return dtensor_output.redistribute(placements=placements).to_local()


def apply_non_moe_tp(
model: nn.Module,
model_config: PretrainedConfig,
Expand Down Expand Up @@ -302,27 +315,33 @@ def apply_non_moe_tp(
}
)

# Qwen models: norm→lm_head path has DTensor-incompatible ops, so we localize
# norm output to Replicate; lm_head input layout must match accordingly
use_local_final_norm_output = is_qwen_model(model_config.model_type)
head_input_layout = Replicate() if use_local_final_norm_output else Shard(1)

# For root module
root_tp_plan: dict[str, ParallelStyle] = {}
if hasattr(model, "lm_head") and isinstance(model.lm_head, nn.Module):
# Implicitly all-gather in ColwiseParallel
# Output is sharded on the last dimension (Shard(2))
# Implicitly all-gather in ColwiseParallel when the input is Shard(1).
# Output is sharded on the last dimension (Shard(2)).
root_tp_plan["lm_head"] = ColwiseParallel(
input_layouts=Shard(1),
input_layouts=head_input_layout,
)
if hasattr(model, "score") and isinstance(model.score, nn.Module):
# For PPO's critic model's score layer:
# 1. The input is sharded by sequence parallelism (Shard(1))
# 1. The input follows the final norm output layout
# 2. `score` is a linear layer with replicated weights
# 3. All-gather the output along the sequence dimension to get the full results
root_tp_plan["score"] = ReplicateParallel(
input_layout=Shard(1),
desired_input_layout=Shard(1),
input_layout=head_input_layout,
desired_input_layout=head_input_layout,
output_layout=Replicate(),
)

if is_valid_vision_model(model_config.model_type):
if isinstance(model.model.language_model, nn.Module):
backbone = model.model.language_model
# For vision-language models, avoid sharding the embedding layer because
# the visual components access it without tensor parallelism support.
# Instead, configure the first transformer layer to handle input
Expand All @@ -342,10 +361,10 @@ def apply_non_moe_tp(
patch_qwen3_vl_deepstack_process_for_tp,
)

patch_qwen3_vl_deepstack_process_for_tp(model.model.language_model)
patch_qwen3_vl_deepstack_process_for_tp(backbone)

parallelize_module(
model.model.language_model,
backbone,
device_mesh=tp_device_mesh,
parallelize_plan=model_tp_plan,
)
Expand All @@ -354,8 +373,9 @@ def apply_non_moe_tp(
"Vision model does not have the required submodule 'model.language_model'"
)
else:
backbone = model.model
parallelize_module(
model.model,
backbone,
device_mesh=tp_device_mesh,
parallelize_plan=model_tp_plan,
)
Expand All @@ -366,6 +386,18 @@ def apply_non_moe_tp(
parallelize_plan=root_tp_plan,
)

# Register norm localization hook after parallelize_module so norm is already
# SequenceParallel. We reuse the `backbone` determined above to avoid
# re-deriving it.
if use_local_final_norm_output:
norm = getattr(backbone, "norm", None)
if not isinstance(norm, nn.Module):
raise RuntimeError(
f"Model backbone ({type(backbone).__name__}) does not have 'norm' submodule, "
f"but localized norm output is required for model_type={model_config.model_type!r}."
)
cast(nn.Module, norm).register_forward_hook(_localize_dtensor_output)


def parallelize_model(
model: nn.Module,
Expand Down
Loading