Skip to content

fix(fsdp engine): localize DTensor norm output for Qwen models in TP#1365

Open
HT-Yuan wants to merge 1 commit into
areal-project:mainfrom
HT-Yuan:fix/qwen-tp-dtensor-localize
Open

fix(fsdp engine): localize DTensor norm output for Qwen models in TP#1365
HT-Yuan wants to merge 1 commit into
areal-project:mainfrom
HT-Yuan:fix/qwen-tp-dtensor-localize

Conversation

@HT-Yuan
Copy link
Copy Markdown
Collaborator

@HT-Yuan HT-Yuan commented May 25, 2026

Description

Qwen models have intermediate ops (aten.alias, aten.slice) between the final norm and lm_head that break DTensor dispatch under tensor parallelism. This commit:

  • Adds is_qwen_model() helper to identify Qwen model family.
  • Registers a forward hook on the final norm to redistribute its DTensor output to Replicate and convert to a local tensor.
  • Adjusts lm_head/score input_layouts to Replicate() for Qwen models so the downstream linear layers receive plain tensors.
  • Extracts backbone variable to avoid redundant attribute access.

Without this fix, Qwen models crash with DTensor dispatch errors when running with TP > 1.

Changes

  • areal/engine/core/model.py: Add is_qwen_model() utility function.
  • areal/engine/fsdp_utils/parallel.py:
    • Add _localize_dtensor_output hook function.
    • Conditionally set head_input_layout based on model type.
    • Register hook on final norm after parallelize_module.
    • Extract backbone for clarity and add type-checking guards.

Related Issue

Fixes #1366

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

Additional Context


Need help? Check the Contributing Guide or ask in
GitHub Discussions!

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for localizing DTensor outputs in Qwen models during tensor parallelization. It adds a _localize_dtensor_output hook and logic to register it on the model's final normalization layer, which is necessary for models where the path between the norm and the language head contains operations incompatible with DTensors. A review comment identifies a redundant type check for the model backbone that should be removed to simplify the code.

Comment thread areal/engine/fsdp_utils/parallel.py Outdated
Comment on lines +376 to +378
if not isinstance(model.model, nn.Module):
raise RuntimeError("Model does not have the required submodule 'model'.")
backbone = model.model
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This check is redundant because model.model is already verified to be an instance of nn.Module at the beginning of the apply_non_moe_tp function (lines 261-262). You can directly assign backbone = model.model here.

        backbone = model.model

Qwen models have intermediate ops (aten.alias, aten.slice) between
the final norm and lm_head that break DTensor dispatch under tensor
parallelism. This commit:

- Adds is_qwen_model() helper to identify Qwen model family.
- Registers a forward hook on the final norm to redistribute its
  DTensor output to Replicate and convert to a local tensor.
- Adjusts lm_head/score input_layouts to Replicate() for Qwen models
  so the downstream linear layers receive plain tensors.
- Extracts backbone variable to avoid redundant attribute access.

Without this fix, Qwen models crash with DTensor dispatch errors
when running with TP > 1.
@HT-Yuan HT-Yuan force-pushed the fix/qwen-tp-dtensor-localize branch from 16b4658 to f1f3ae8 Compare May 25, 2026 12:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] Qwen models crash with DTensor dispatch error under TP > 1

2 participants