diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 93e98f234d0..9d8b0df11c0 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -188,7 +188,14 @@ def _build_tf_config(self): from verl.models.mcore.mbridge import AutoBridge bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype) + if not self.model_config.mtp.enable: + override_transformer_config["mtp_num_layers"] = 0 bridge.set_extra_args(**override_transformer_config) + if not self.model_config.mtp.enable: + csa = getattr(bridge.config, "csa_compress_ratios", None) + num_layers = getattr(bridge.config, "num_layers", None) + if csa is not None and num_layers is not None and len(csa) > num_layers: + bridge.config.csa_compress_ratios = list(csa[:num_layers]) tf_config = bridge.config tf_config.fp16 = self.param_dtype == torch.float16 tf_config.bf16 = self.param_dtype == torch.bfloat16