From bec9ef74768dd201881cd4e54cd0385e87caae27 Mon Sep 17 00:00:00 2001 From: Joel Date: Mon, 16 Mar 2026 10:41:30 +0800 Subject: [PATCH 1/2] [misc] chore: bump version to 0.7.1 (#5602) ### What does this PR do? Bump version to 0.7.1 --- verl/version/version | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/version/version b/verl/version/version index 7188dbafb43..39e898a4f95 100644 --- a/verl/version/version +++ b/verl/version/version @@ -1 +1 @@ -0.8.0.dev +0.7.1 From 062415275f29ec442890041dec6639ef5238ba02 Mon Sep 17 00:00:00 2001 From: kolehma8 Date: Thu, 28 May 2026 11:58:28 -0700 Subject: [PATCH 2/2] [megatron] feat: align optimizer states and DDP grad bucket with model precision In bf16 training the optimizer state (Adam m, v) and grad accumulation buffer were kept in fp32, costing ~3x the necessary memory. This change makes them match the model dtype: - `init_megatron_optim_config`: add a `bf16: bool = True` kwarg so the function can dispatch on the actual precision instead of inferring from a single `fp16` flag. The bf16 branch now enables `use_precision_aware_optimizer=True` and pins `main_grads_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype` to `torch.bfloat16`. `main_params_dtype` stays at fp32 because TE FusedAdam currently rejects bf16 master weights. fp16 path is unchanged; an explicit fp32 branch now exists and keeps Megatron's defaults. - `get_model` (megatron_utils): derive `grad_reduce_in_fp32 = not tfconfig.bf16` so the DDP grad bucket matches the optimizer's `main_grads_dtype`. fp16/fp32 paths keep the fp32 grad bucket. User overrides via `override_ddp_config` still win. - Update three call sites (`megatron_workers.py` actor + critic, `transformer_impl.py`) to pass `bf16=<dtype> == torch.bfloat16`, where `<dtype>` is `self.dtype` in `megatron_workers.py` and `self.param_dtype` in `transformer_impl.py`. Verified end-to-end on Qwen3-30B-A3B GRPO: the optimizer config dump shows the bf16 dtypes landing, model loads, and step-0 validation completes (~51% gsm8k acc) before unrelated OOM from rollout colocation. 
Adds CPU-only unit tests in tests/utils/megatron/test_optimizer.py (9 tests, all passing) covering the bf16/fp16/fp32 dispatch, default-kwarg backward compatibility, override precedence, and basic field passthrough. Tests monkey-patch `OptimizerConfig` to verify which kwargs verl assembles without invoking Megatron's downstream validators. --- tests/utils/megatron/test_optimizer.py | 179 ++++++++++++++++++ verl/utils/megatron/optimizer.py | 36 +++- verl/utils/megatron_utils.py | 8 +- .../engine/megatron/transformer_impl.py | 1 + verl/workers/megatron_workers.py | 2 + 5 files changed, 223 insertions(+), 3 deletions(-) create mode 100644 tests/utils/megatron/test_optimizer.py diff --git a/tests/utils/megatron/test_optimizer.py b/tests/utils/megatron/test_optimizer.py new file mode 100644 index 00000000000..25cb3444094 --- /dev/null +++ b/tests/utils/megatron/test_optimizer.py @@ -0,0 +1,179 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the precision-aware dispatch in ``init_megatron_optim_config``. + +These tests stub out ``megatron.core.optimizer.OptimizerConfig`` so they can +run on CPU without TransformerEngine — the goal is to verify which kwargs +verl assembles for each precision mode, not Megatron's downstream validation. 
+""" + +from unittest.mock import MagicMock + +import pytest +import torch +from omegaconf import OmegaConf + +from verl.utils.megatron import optimizer as opt_mod +from verl.utils.megatron.optimizer import init_megatron_optim_config + + +def _base_optim_config(): + return OmegaConf.create( + { + "optimizer": "adam", + "lr": 1e-3, + "min_lr": 0.0, + "clip_grad": 1.0, + "weight_decay": 0.01, + } + ) + + +@pytest.fixture +def captured_args(monkeypatch): + """Replace ``OptimizerConfig`` with a recorder so we can inspect kwargs.""" + captured: dict = {} + + def _fake(**kwargs): + captured.clear() + captured.update(kwargs) + return MagicMock(name="OptimizerConfig", **kwargs) + + monkeypatch.setattr(opt_mod, "OptimizerConfig", _fake) + return captured + + +def test_bf16_branch_enables_precision_aware_optimizer_with_bf16_state(captured_args): + init_megatron_optim_config(_base_optim_config(), fp16=False, bf16=True) + + assert captured_args["bf16"] is True + assert captured_args["params_dtype"] is torch.bfloat16 + assert captured_args["use_precision_aware_optimizer"] is True + assert captured_args["main_grads_dtype"] is torch.bfloat16 + assert captured_args["exp_avg_dtype"] is torch.bfloat16 + assert captured_args["exp_avg_sq_dtype"] is torch.bfloat16 + # Master params dtype intentionally left at Megatron default (fp32) — + # TE FusedAdam rejects bf16 master at init. 
+ assert "main_params_dtype" not in captured_args + + +def test_fp16_branch_uses_precision_aware_but_keeps_fp32_optimizer_state(captured_args): + init_megatron_optim_config(_base_optim_config(), fp16=True, bf16=False) + + assert captured_args["fp16"] is True + assert captured_args["bf16"] is False + assert captured_args["params_dtype"] is torch.float16 + assert captured_args["use_precision_aware_optimizer"] is True + assert captured_args["initial_loss_scale"] == 32768 + assert captured_args["min_loss_scale"] == 1 + assert captured_args["store_param_remainders"] is False + # Adam moment / grad dtypes left at Megatron's fp32 default in fp16 mode. + assert "main_grads_dtype" not in captured_args + assert "exp_avg_dtype" not in captured_args + assert "exp_avg_sq_dtype" not in captured_args + + +def test_fp32_branch_disables_precision_aware_optimizer(captured_args): + init_megatron_optim_config(_base_optim_config(), fp16=False, bf16=False) + + assert captured_args["fp16"] is False + assert captured_args["bf16"] is False + assert captured_args["params_dtype"] is torch.float32 + # Precision-aware optimizer must stay off — Megatron asserts the dtype + # fields equal fp32 when it's disabled. 
+ assert "use_precision_aware_optimizer" not in captured_args + assert "main_grads_dtype" not in captured_args + assert "exp_avg_dtype" not in captured_args + assert "exp_avg_sq_dtype" not in captured_args + + +def test_default_kwargs_dispatch_to_bf16_branch(captured_args): + """Backward compatibility: callers that omit ``bf16`` get the bf16 path.""" + init_megatron_optim_config(_base_optim_config()) + + assert captured_args["bf16"] is True + assert captured_args["params_dtype"] is torch.bfloat16 + assert captured_args["use_precision_aware_optimizer"] is True + + +def test_fp16_wins_over_bf16_when_both_true(captured_args): + init_megatron_optim_config(_base_optim_config(), fp16=True, bf16=True) + + assert captured_args["fp16"] is True + assert captured_args["params_dtype"] is torch.float16 + # bf16-branch-only fields must not appear when fp16 is selected. + assert "main_grads_dtype" not in captured_args + assert "exp_avg_dtype" not in captured_args + + +def test_use_distributed_optimizer_passes_through(captured_args): + init_megatron_optim_config(_base_optim_config(), use_distributed_optimizer=False) + assert captured_args["use_distributed_optimizer"] is False + + init_megatron_optim_config(_base_optim_config(), use_distributed_optimizer=True) + assert captured_args["use_distributed_optimizer"] is True + + +def test_basic_optim_config_fields_pass_through(captured_args): + cfg = OmegaConf.create( + { + "optimizer": "sgd", + "lr": 5e-4, + "min_lr": 1e-5, + "clip_grad": 0.5, + "weight_decay": 0.1, + } + ) + init_megatron_optim_config(cfg) + + assert captured_args["optimizer"] == "sgd" + assert captured_args["lr"] == pytest.approx(5e-4) + assert captured_args["min_lr"] == pytest.approx(1e-5) + assert captured_args["clip_grad"] == pytest.approx(0.5) + assert captured_args["weight_decay"] == pytest.approx(0.1) + + +def test_override_optimizer_config_overrides_branch_defaults(captured_args): + cfg = OmegaConf.create( + { + "optimizer": "adam", + "lr": 1e-3, + "min_lr": 
0.0, + "clip_grad": 1.0, + "weight_decay": 0.01, + "override_optimizer_config": { + "use_precision_aware_optimizer": False, + "exp_avg_dtype": "sentinel-override", + }, + } + ) + init_megatron_optim_config(cfg, bf16=True) + + # User-supplied overrides win over the bf16-branch defaults … + assert captured_args["use_precision_aware_optimizer"] is False + assert captured_args["exp_avg_dtype"] == "sentinel-override" + # … but non-overridden bf16 defaults remain. + assert captured_args["main_grads_dtype"] is torch.bfloat16 + assert captured_args["exp_avg_sq_dtype"] is torch.bfloat16 + + +def test_missing_override_config_leaves_branch_defaults_intact(captured_args): + """``optim_config.get('override_optimizer_config', {})`` must not crash when absent.""" + cfg = _base_optim_config() + assert "override_optimizer_config" not in cfg + + init_megatron_optim_config(cfg, bf16=True) + + assert captured_args["use_precision_aware_optimizer"] is True + assert captured_args["exp_avg_dtype"] is torch.bfloat16 diff --git a/verl/utils/megatron/optimizer.py b/verl/utils/megatron/optimizer.py index 625cd24b921..c71655f3275 100644 --- a/verl/utils/megatron/optimizer.py +++ b/verl/utils/megatron/optimizer.py @@ -22,7 +22,10 @@ def init_megatron_optim_config( - optim_config: dict, use_distributed_optimizer: bool = True, fp16: bool = False + optim_config: dict, + use_distributed_optimizer: bool = True, + fp16: bool = False, + bf16: bool = True, ) -> OptimizerConfig: optim_args = { "optimizer": optim_config.optimizer, @@ -44,11 +47,40 @@ def init_megatron_optim_config( "store_param_remainders": False, } ) - else: # bf16 mode + elif bf16: + # Match precision: keep the grad-accumulation buffer and Adam + # moments (m, v) in bf16 so optimizer-state memory tracks the + # model dtype. Master parameters stay fp32 (Megatron default + # `main_params_dtype`) because TE FusedAdam currently rejects + # bf16 master weights at init (only fp32/fp16 accepted). 
The + # int16 "store_param_remainders" path (Megatron default True + # in bf16 mode) already eliminates the fp32 master buffer in + # favor of bf16 working + int16 remainders, achieving the same + # ~50% master-memory reduction. + # Requires TransformerEngine's FusedAdam (already needed by + # the precision-aware optimizer path). Override any of these + # via `override_optimizer_config` to opt back into fp32. optim_args.update( { "bf16": True, "params_dtype": torch.bfloat16, + "use_precision_aware_optimizer": True, + "main_grads_dtype": torch.bfloat16, + "exp_avg_dtype": torch.bfloat16, + "exp_avg_sq_dtype": torch.bfloat16, + } + ) + else: + # fp32 mode: leave grad-accumulation buffer and Adam moments at + # Megatron's default torch.float32. Do not enable the precision-aware + # optimizer — it's only beneficial when a moment/grad dtype is below + # fp32, and Megatron asserts the dtype fields equal fp32 whenever the + # precision-aware optimizer is off (optimizer_config.py:258-268). + optim_args.update( + { + "bf16": False, + "fp16": False, + "params_dtype": torch.float32, } ) override_config = optim_config.get("override_optimizer_config", {}) diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index dba32d02539..69a2d9e52da 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -142,10 +142,16 @@ def get_model( model = [Float16Module(config, model_module) for model_module in model] if wrap_with_ddp: + # Derive the grad-bucket dtype from model precision so the DDP buffer + # matches the optimizer's main_grads_dtype: bf16 weights → bf16 grad + # bucket (paired with the precision-aware optimizer in + # init_megatron_optim_config); fp16/fp32 weights → fp32 grad bucket. + # User overrides via `override_ddp_config` still win. 
+ grad_reduce_in_fp32 = not getattr(tfconfig, "bf16", False) ddp_models = [] ddp_config_dict = { "use_distributed_optimizer": use_distributed_optimizer, - "grad_reduce_in_fp32": True, + "grad_reduce_in_fp32": grad_reduce_in_fp32, "overlap_grad_reduce": False, } if override_ddp_config is not None: diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 871b6f8c61c..651213c0e7d 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -282,6 +282,7 @@ def _build_optimizer(self): self.optimizer_config, use_distributed_optimizer=self.engine_config.use_distributed_optimizer, fp16=self.param_dtype == torch.float16, + bf16=self.param_dtype == torch.bfloat16, ) optimizer = get_megatron_optimizer(model=self.module, config=optim_config_megatron) register_megatron_training_hooks(self.module, optimizer) diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 31dfe7e9341..31749b9b9fe 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -489,6 +489,7 @@ def _build_model_optimizer( optim_config, use_distributed_optimizer=wrap_config.use_distributed_optimizer, fp16=self.dtype == torch.float16, + bf16=self.dtype == torch.bfloat16, ) actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron) actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler( @@ -1134,6 +1135,7 @@ def _build_critic_model_optimizer( optim_config, use_distributed_optimizer=wrap_config.use_distributed_optimizer, fp16=self.dtype == torch.float16, + bf16=self.dtype == torch.bfloat16, ) critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron) critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler(