From e48aadc9f884d99ece5a3c597f5f0c6d5b1474df Mon Sep 17 00:00:00 2001 From: Lingrui Mei Date: Fri, 29 May 2026 22:53:57 +0800 Subject: [PATCH] [megatron] fix: zero out mtp_num_layers and trim csa_compress_ratios on vanilla_mbridge=True path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #6473 added the same fix to the vanilla_mbridge=False (NeMo MB) path of MegatronEngine._build_tf_config. The vanilla_mbridge=True (ISEEKYAN/mbridge) path needs the symmetric treatment: when self.model_config.mtp.enable is False, force mtp_num_layers=0 so the bridge does not build MTP blocks, and trim the per-layer csa_compress_ratios list to num_layers entries (DSv4-Flash HF configs pad it for the MTP layer when num_nextn_predict_layers > 0). mtp_num_layers uses direct assignment (not setdefault) so a disabled-MTP run always forces 0 even if override_transformer_config carried a stale value. Why this is not a duplicate: #6473 only modifies the vanilla_mbridge=False branch. This PR modifies the vanilla_mbridge=True branch — different code path, complementary fix. Test plan: validated end-to-end on GB200 (1 GPU) through ISEEKYAN/mbridge + DSv4 hybrid attention — forward + backward + optimizer.step() with the vanilla_mbridge=True path produces finite loss / finite grad_norm / update_successful=True. AI assistance disclosure: developed with AI-assisted coding (Claude); author reviewed every changed line. 
Co-authored-by: Claude Signed-off-by: Lingrui Mei --- verl/workers/engine/megatron/transformer_impl.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 93e98f234d0..9d8b0df11c0 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -188,7 +188,14 @@ def _build_tf_config(self): from verl.models.mcore.mbridge import AutoBridge bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype) + if not self.model_config.mtp.enable: + override_transformer_config["mtp_num_layers"] = 0 bridge.set_extra_args(**override_transformer_config) + if not self.model_config.mtp.enable: + csa = getattr(bridge.config, "csa_compress_ratios", None) + num_layers = getattr(bridge.config, "num_layers", None) + if csa is not None and num_layers is not None and len(csa) > num_layers: + bridge.config.csa_compress_ratios = list(csa[:num_layers]) tf_config = bridge.config tf_config.fp16 = self.param_dtype == torch.float16 tf_config.bf16 = self.param_dtype == torch.bfloat16