[megatron] fix: zero out mtp_num_layers and trim csa_compress_ratios on vanilla_mbridge=True path#6515

Open
Meirtz wants to merge 1 commit into verl-project:main from Meirtz:fix/dsv4-mtp-disable-vanilla-mbridge

@Meirtz Meirtz commented May 28, 2026

Summary

MegatronEngine._build_tf_config (verl/workers/engine/megatron/transformer_impl.py)
has two branches keyed off vanilla_mbridge:

  • vanilla_mbridge=False (NeMo Megatron-Bridge path) — transformer_impl.py:~190+
  • vanilla_mbridge=True (ISEEKYAN/mbridge path) — transformer_impl.py:~181+

PR #6473 added the MTP-disable + csa_compress_ratios trim fix to the
vanilla_mbridge=False branch. This PR applies the symmetric fix to the
vanilla_mbridge=True branch.

Why this is not a duplicate of #6473

#6473 modifies the vanilla_mbridge=False branch (NeMo MB to_megatron_provider()
+ apply_overrides_and_finalize path). This PR modifies the vanilla_mbridge=True
branch (ISEEKYAN/mbridge AutoBridge.from_config + set_extra_args path). The fix
idea is the same, but the code path is complementary. A search with
gh pr list --search "vanilla_mbridge mtp_num_layers in:body" returned only this PR.

The bug

Without this fix on the vanilla_mbridge=True path:

  1. bridge.set_extra_args(**override_transformer_config) rebuilds
    MLATransformerConfig from the HF config.
  2. The bridge derives mtp_num_layers from num_nextn_predict_layers and pads
    csa_compress_ratios to num_layers + mtp_num_layers.
  3. DSv4-Flash HF configs ship num_nextn_predict_layers > 0, so callers that
    disable MTP at runtime via model_config.mtp.enable=False still get
    mtp_num_layers > 0 and len(csa_compress_ratios) > num_layers.
  4. The mismatch propagates into MTP-block construction and CSA-schedule indexing.
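
The length mismatch in steps 2 and 3 can be illustrated with a toy sketch. The function below is a hypothetical stand-in, not the bridge's actual code: derive_layer_fields, pad_value, and the example ratios are assumptions made only to show how a padded list ends up longer than num_layers.

```python
# Toy stand-in for the derivation described in the list above.
# NOT the real mbridge code; all names and values here are hypothetical.
def derive_layer_fields(num_layers, num_nextn_predict_layers, base_ratios, pad_value=1):
    # mtp_num_layers is taken straight from the HF field.
    mtp_num_layers = num_nextn_predict_layers
    # The per-layer list is padded to cover the decoder layers plus the MTP layers.
    target_len = num_layers + mtp_num_layers
    padded = list(base_ratios) + [pad_value] * (target_len - len(base_ratios))
    return {
        "num_layers": num_layers,
        "mtp_num_layers": mtp_num_layers,
        "csa_compress_ratios": padded,
    }


cfg = derive_layer_fields(num_layers=4, num_nextn_predict_layers=1, base_ratios=[1, 4, 4, 128])

# The caller disabled MTP at runtime, but the derived config does not know that.
mtp_enabled = False
mismatch = (not mtp_enabled) and len(cfg["csa_compress_ratios"]) > cfg["num_layers"]
print(cfg["mtp_num_layers"], len(cfg["csa_compress_ratios"]), mismatch)
```

With MTP disabled by the caller, the derived config in this sketch still reports one MTP layer and a five-entry list for a four-layer model.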

The fix

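Same pattern as #6473, applied in two stages because set_extra_args rebuilds the
config. The mtp_num_layers override is set before the call, and
csa_compress_ratios is trimmed on the rebuilt config afterwards: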

if not self.model_config.mtp.enable:
    override_transformer_config.setdefault("mtp_num_layers", 0)
bridge.set_extra_args(**override_transformer_config)
if not self.model_config.mtp.enable:
    csa = getattr(bridge.config, "csa_compress_ratios", None)
    num_layers = getattr(bridge.config, "num_layers", None)
    if csa is not None and num_layers is not None and len(csa) > num_layers:
        bridge.config.csa_compress_ratios = list(csa[:num_layers])
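
A self-contained sketch of the two-stage pattern, for readers without a Megatron environment. FakeBridge is a hypothetical stand-in whose set_extra_args rebuilds its config on every call; the assumption that padding still follows the HF value after the override is made for illustration and is not confirmed bridge behavior. The sketch uses direct assignment for mtp_num_layers, which is what the commit message later in this thread describes.

```python
from types import SimpleNamespace


class FakeBridge:
    """Hypothetical stand-in for the bridge; not the real AutoBridge."""

    def __init__(self, num_layers, num_nextn_predict_layers):
        self._num_layers = num_layers
        self._nextn = num_nextn_predict_layers
        self.config = None

    def set_extra_args(self, **overrides):
        # Rebuild the config from scratch, as the PR description says the bridge does.
        mtp = overrides.get("mtp_num_layers", self._nextn)
        # Assumption: the ratio list is still padded from the HF value.
        ratios = [4] * (self._num_layers + self._nextn)
        self.config = SimpleNamespace(
            num_layers=self._num_layers,
            mtp_num_layers=mtp,
            csa_compress_ratios=ratios,
        )


def build_config(bridge, overrides, mtp_enabled):
    # Stage 1: the override must go in before the config is rebuilt.
    if not mtp_enabled:
        overrides["mtp_num_layers"] = 0
    bridge.set_extra_args(**overrides)
    # Stage 2: the trim must run on the rebuilt config.
    if not mtp_enabled:
        csa = getattr(bridge.config, "csa_compress_ratios", None)
        num_layers = getattr(bridge.config, "num_layers", None)
        if csa is not None and num_layers is not None and len(csa) > num_layers:
            bridge.config.csa_compress_ratios = list(csa[:num_layers])
    return bridge.config


disabled = build_config(FakeBridge(4, 1), {}, mtp_enabled=False)
enabled = build_config(FakeBridge(4, 1), {}, mtp_enabled=True)
print(disabled.mtp_num_layers, len(disabled.csa_compress_ratios))
print(enabled.mtp_num_layers, len(enabled.csa_compress_ratios))
```

In this sketch the enabled case is left untouched, which matches the intent that the fix only takes effect when MTP is disabled.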

Test plan

End-to-end validated on a single GB200 GPU through the ISEEKYAN/mbridge
DSv4 path:

  • Forward-only smoke through verl.models.mcore.mbridge.AutoBridge
    bridge.get_model(): loss=0.000144 finite.
  • Full training step (forward + backward + optimizer.step()) via
    verl.utils.megatron.optimizer.get_megatron_optimizer (bf16 Adam),
    exercising window attention, CSA, DSA indexer, mHC, hash MoE, and MTP layer:
    loss=0.000170, grad_norm=0.135386, update_successful=True.

The MTP-disable conditional and trim arithmetic are identical to #6473's
vanilla_mbridge=False fix, so the same fix is exercised on every DSv4
attention path that goes through _build_tf_config.

AI assistance disclosure

This change was developed with AI-assisted coding (Claude). The author has
reviewed every changed line and personally executed the test plan above.

Companion PR (defensive bridge-side default that complements this engine-side
fix): NVIDIA-NeMo/Megatron-Bridge#4003.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request removes the NPU MTP inference patch logic and introduces a fix for the vanilla MBridge path when MTP is disabled by zeroing out mtp_num_layers and trimming csa_compress_ratios. The reviewer noted that removing the NPU MTP patch entirely breaks NPU support and suggested using a try-except block to safely import the Megatron version instead. Additionally, the reviewer recommended using direct assignment instead of setdefault to guarantee that mtp_num_layers is correctly overridden and set to 0 when MTP is disabled.

Comment on lines 124 to 125
from verl.utils.device import is_cuda_available

Contributor

high

Completely removing the NPU MTP inference patch and the is_npu_available check breaks MTP inference support on NPU platforms. If the removal of __version__ from megatron.core.package_info was intended to prevent import errors on certain Megatron versions, we should instead import it safely using a try-except block locally, rather than dropping NPU support entirely.

Suggested change
- from verl.utils.device import is_cuda_available
+ from verl.utils.device import is_cuda_available, is_npu_available
+ try:
+     from megatron.core.package_info import __version__
+ except ImportError:
+     __version__ = "0.0.0"
+ if is_npu_available and __version__ >= "0.16.0":
+     from verl.models.mcore.patch import apply_mtp_inference_patch
+     apply_mtp_inference_patch()
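
One caveat about a check of the form `__version__ >= "0.16.0"`: Python compares strings lexicographically, so a string such as "0.9.0" compares as greater than "0.16.0". A minimal sketch of a numeric comparison follows, assuming dotted numeric version strings. version_tuple is a hypothetical helper, not part of verl or Megatron.

```python
import re


def version_tuple(version: str) -> tuple:
    """Hypothetical helper: leading integer of each dotted component.

    Stops at the first component with no leading digits, so suffixes such as
    "rc1" are ignored rather than ordered.
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


# Lexicographic string comparison: "9" sorts after "1", so this is True.
string_result = "0.9.0" >= "0.16.0"
# Numeric tuple comparison: (0, 9, 0) is less than (0, 16, 0), so this is False.
tuple_result = version_tuple("0.9.0") >= version_tuple("0.16.0")
print(string_result, tuple_result)
```

Where the `packaging` library is available, `packaging.version.Version` handles pre-release suffixes as well, which this sketch does not.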

Collaborator

Why did you choose to remove the logic here? I think this change is unrelated to this PR unless you have tested it on NPU. (I don't have access to NPU clusters, so I can't say whether the changes here are reasonable.) @Meirtz

Comment on lines +187 to +188
if not self.model_config.mtp.enable:
    override_transformer_config.setdefault("mtp_num_layers", 0)
Contributor

high

Using setdefault here will not override "mtp_num_layers" if it is already present in override_transformer_config (e.g., if set by user configuration or previous steps). To guarantee that MTP is completely disabled when self.model_config.mtp.enable is False, we should explicitly set "mtp_num_layers" to 0 using direct assignment.

Suggested change
- if not self.model_config.mtp.enable:
-     override_transformer_config.setdefault("mtp_num_layers", 0)
+ if not self.model_config.mtp.enable:
+     override_transformer_config["mtp_num_layers"] = 0
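
The difference between the two forms can be seen with a plain dict. This is a minimal sketch: the stale value of 1 is a hypothetical example of a value carried in by user configuration or an earlier step.

```python
# Hypothetical overrides dict that already carries a stale value.
stale_overrides = {"mtp_num_layers": 1}

# setdefault only writes when the key is missing, so the stale value survives.
with_setdefault = dict(stale_overrides)
with_setdefault.setdefault("mtp_num_layers", 0)

# Direct assignment always overwrites.
with_assignment = dict(stale_overrides)
with_assignment["mtp_num_layers"] = 0

print(with_setdefault["mtp_num_layers"], with_assignment["mtp_num_layers"])
```

When the key is absent the two forms behave the same, so they only diverge in the stale-value case the comment describes.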

Collaborator

@HollowMan6 HollowMan6 left a comment

The changes LGTM and this can be merged before #6473, but I'm curious why you decided to remove that NPU MTP patch.

@Meirtz Meirtz force-pushed the fix/dsv4-mtp-disable-vanilla-mbridge branch from dd451b7 to 7d1128c Compare May 29, 2026 14:54
Author

Meirtz commented May 29, 2026

Thanks @HollowMan6 — not intentional, the NPU patch got dropped while the branch was on an older base. Rebased onto latest main so it's fully restored, and the diff is now just the vanilla-mbridge fix. Also applied @gemini-code-assist's suggestion to use direct assignment for mtp_num_layers. Agreed on merging this before #6473.

…on vanilla_mbridge=True path

PR verl-project#6473 added the same fix to the vanilla_mbridge=False (NeMo MB) path of
MegatronEngine._build_tf_config. The vanilla_mbridge=True (ISEEKYAN/mbridge)
path needs the symmetric treatment: when self.model_config.mtp.enable is
False, force mtp_num_layers=0 so the bridge does not build MTP blocks, and
trim the per-layer csa_compress_ratios list (DSv4-Flash HF configs pad it for
the MTP layer when num_nextn_predict_layers > 0).

mtp_num_layers uses direct assignment (not setdefault) so a disabled-MTP run
always forces 0 even if override_transformer_config carried a stale value.

Why not duplicate: verl-project#6473 only modifies the vanilla_mbridge=False branch.
This PR modifies the vanilla_mbridge=True branch — different code path,
complementary fix.

Test plan: validated end-to-end on GB200 (1 GPU) through ISEEKYAN/mbridge +
DSv4 hybrid attention — forward + backward + optimizer.step() with the
vanilla=True path produces finite loss / finite grad_norm /
update_successful=True.

AI assistance disclosure: developed with AI-assisted coding (Claude); author
reviewed every changed line.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Lingrui Mei <lmei@nvidia.com>
@Meirtz Meirtz force-pushed the fix/dsv4-mtp-disable-vanilla-mbridge branch from 7d1128c to e48aadc Compare May 29, 2026 15:19
@HollowMan6 HollowMan6 requested a review from wuxibin89 May 29, 2026 15:35