fix(fsdp engine): localize DTensor norm output for Qwen models in TP#1365
Open
HT-Yuan wants to merge 1 commit into
Open
fix(fsdp engine): localize DTensor norm output for Qwen models in TP#1365HT-Yuan wants to merge 1 commit into
HT-Yuan wants to merge 1 commit into
Conversation
Contributor
There was a problem hiding this comment.
Code Review
This pull request introduces support for localizing DTensor outputs in Qwen models during tensor parallelization. It adds a _localize_dtensor_output hook and logic to register it on the model's final normalization layer, which is necessary for models where the path between the norm and the language head contains operations incompatible with DTensors. A review comment identifies a redundant type check for the model backbone that should be removed to simplify the code.
Comment on lines
+376
to
+378
| if not isinstance(model.model, nn.Module): | ||
| raise RuntimeError("Model does not have the required submodule 'model'.") | ||
| backbone = model.model |
Contributor
Qwen models have intermediate ops (aten.alias, aten.slice) between the final norm and lm_head that break DTensor dispatch under tensor parallelism. This commit: - Adds is_qwen_model() helper to identify Qwen model family. - Registers a forward hook on the final norm to redistribute its DTensor output to Replicate and convert to a local tensor. - Adjusts lm_head/score input_layouts to Replicate() for Qwen models so the downstream linear layers receive plain tensors. - Extracts backbone variable to avoid redundant attribute access. Without this fix, Qwen models crash with DTensor dispatch errors when running with TP > 1.
16b4658 to
f1f3ae8
Compare
sitabulaixizawaluduo
approved these changes
May 25, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Qwen models have intermediate ops (aten.alias, aten.slice) between the final norm and lm_head that break DTensor dispatch under tensor parallelism. This commit:
Without this fix, Qwen models crash with DTensor dispatch errors when running with TP > 1.
Changes
areal/engine/core/model.py: Addis_qwen_model()utility function.areal/engine/fsdp_utils/parallel.py:_localize_dtensor_outputhook function.head_input_layoutbased on model type.parallelize_module.backbonefor clarity and add type-checking guards.Related Issue
Fixes #1366
Type of Change
Checklist
pre-commit run --all-files)./docs/build_all.sh)main/review-prcommand/create-prBreaking Change Details (if applicable):
Additional Context
Need help? Check the Contributing Guide or ask in
GitHub Discussions!