[megatron] fix: zero out mtp_num_layers and trim csa_compress_ratios on vanilla_mbridge=True path#6515
Conversation
There was a problem hiding this comment.
Code Review
This pull request removes the NPU MTP inference patch logic and introduces a fix for the vanilla MBridge path when MTP is disabled by zeroing out mtp_num_layers and trimming csa_compress_ratios. The reviewer noted that removing the NPU MTP patch entirely breaks NPU support and suggested using a try-except block to safely import the Megatron version instead. Additionally, the reviewer recommended using direct assignment instead of setdefault to guarantee that mtp_num_layers is correctly overridden and set to 0 when MTP is disabled.
| from verl.utils.device import is_cuda_available | ||
|
|
There was a problem hiding this comment.
Completely removing the NPU MTP inference patch and the is_npu_available check breaks MTP inference support on NPU platforms. If the removal of __version__ from megatron.core.package_info was intended to prevent import errors on certain Megatron versions, we should instead import it safely using a try-except block locally, rather than dropping NPU support entirely.
| from verl.utils.device import is_cuda_available | |
| from verl.utils.device import is_cuda_available, is_npu_available | |
| try: | |
| from megatron.core.package_info import __version__ | |
| except ImportError: | |
| __version__ = "0.0.0" | |
| if is_npu_available and __version__ >= "0.16.0": | |
| from verl.models.mcore.patch import apply_mtp_inference_patch | |
| apply_mtp_inference_patch() |
There was a problem hiding this comment.
Why do you choose to remove the logic here? I think this change is unrelated to this PR unless you have tested this on NPU (I don't really have access to NPU clusters so I can't say the changes here are reasonable or not) @Meirtz
| if not self.model_config.mtp.enable: | ||
| override_transformer_config.setdefault("mtp_num_layers", 0) |
There was a problem hiding this comment.
Using setdefault here will not override "mtp_num_layers" if it is already present in override_transformer_config (e.g., if set by user configuration or previous steps). To guarantee that MTP is completely disabled when self.model_config.mtp.enable is False, we should explicitly set "mtp_num_layers" to 0 using direct assignment.
| if not self.model_config.mtp.enable: | |
| override_transformer_config.setdefault("mtp_num_layers", 0) | |
| if not self.model_config.mtp.enable: | |
| override_transformer_config["mtp_num_layers"] = 0 |
There was a problem hiding this comment.
The changes LGTM and this can be merged before #6473, but I'm curious about why do you decide to remove that NPU mtp patch
| from verl.utils.device import is_cuda_available | ||
|
|
There was a problem hiding this comment.
Why do you choose to remove the logic here? I think this change is unrelated to this PR unless you have tested this on NPU (I don't really have access to NPU clusters so I can't say the changes here are reasonable or not) @Meirtz
dd451b7 to
7d1128c
Compare
|
Thanks @HollowMan6 — not intentional, the NPU patch got dropped while the branch was on an older base. Rebased onto latest main so it's fully restored, and the diff is now just the vanilla-mbridge fix. Also applied @gemini-code-assist's suggestion to use direct assignment for |
…on vanilla_mbridge=True path PR verl-project#6473 added the same fix to the vanilla_mbridge=False (NeMo MB) path of MegatronEngine._build_tf_config. The vanilla_mbridge=True (ISEEKYAN/mbridge) path needs the symmetric treatment: when self.model_config.mtp.enable is False, force mtp_num_layers=0 so the bridge does not build MTP blocks, and trim the per-layer csa_compress_ratios list (DSv4-Flash HF configs pad it for the MTP layer when num_nextn_predict_layers > 0). mtp_num_layers uses direct assignment (not setdefault) so a disabled-MTP run always forces 0 even if override_transformer_config carried a stale value. Why not duplicate: verl-project#6473 only modifies the vanilla_mbridge=False branch. This PR modifies the vanilla_mbridge=True branch — different code path, complementary fix. Test plan: validated end-to-end on GB200 (1 GPU) through ISEEKYAN/mbridge + DSv4 hybrid attention — forward + backward + optimizer.step() with the vanilla=True path produces finite loss / finite grad_norm / update_successful=True. AI assistance disclosure: developed with AI-assisted coding (Claude); author reviewed every changed line. Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Lingrui Mei <lmei@nvidia.com>
7d1128c to
e48aadc
Compare
Summary
MegatronEngine._build_tf_config(verl/workers/engine/megatron/transformer_impl.py)has two branches keyed off
vanilla_mbridge:vanilla_mbridge=False(NeMo Megatron-Bridge path) —transformer_impl.py:~190+vanilla_mbridge=True(ISEEKYAN/mbridge path) —transformer_impl.py:~181+PR #6473 added the MTP-disable +
csa_compress_ratiostrim fix to thevanilla_mbridge=Falsebranch. This PR applies the symmetric fix to thevanilla_mbridge=Truebranch.Why this is not a duplicate of #6473
#6473 modifies the
vanilla_mbridge=Falsebranch (NeMo MBto_megatron_provider()apply_overrides_and_finalizepath). This PR modifies thevanilla_mbridge=Truebranch (ISEEKYAN/mbridge
AutoBridge.from_config+set_extra_argspath). Samefix idea, complementary code path. Searched
gh pr list --search "vanilla_mbridge mtp_num_layers in:body"— only this PR.The bug
Without this fix on the
vanilla_mbridge=Truepath:bridge.set_extra_args(**override_transformer_config)rebuildsMLATransformerConfigfrom the HF config.mtp_num_layersfromnum_nextn_predict_layersand padscsa_compress_ratiostonum_layers + mtp_num_layers.num_nextn_predict_layers > 0, so callers thatdisable MTP at runtime via
model_config.mtp.enable=Falsestill getmtp_num_layers > 0andlen(csa_compress_ratios) > num_layers.The fix
Same pattern as #6473, two-stage because
set_extra_argsrebuilds the config:Test plan
End-to-end validated on a single GB200 GPU through the ISEEKYAN/mbridge
DSv4 path:
verl.models.mcore.mbridge.AutoBridge→bridge.get_model():loss=0.000144finite.optimizer.step()) viaverl.utils.megatron.optimizer.get_megatron_optimizer(bf16 Adam),exercising window attention, CSA, DSA indexer, mHC, hash MoE, and MTP layer:
loss=0.000170,grad_norm=0.135386,update_successful=True.The MTP-disable conditional and trim arithmetic are identical to #6473's
vanilla_mbridge=Falsefix — same fix is being exercised on every DSv4attention path that goes through
_build_tf_config.AI assistance disclosure
This change was developed with AI-assisted coding (Claude). The author has
reviewed every changed line and personally executed the test plan above.
Companion PR (defensive bridge-side default that complements this engine-side
fix): NVIDIA-NeMo/Megatron-Bridge#4003.