Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions tests/utils/megatron/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the precision-aware dispatch in ``init_megatron_optim_config``.

These tests stub out ``megatron.core.optimizer.OptimizerConfig`` so they can
run on CPU without TransformerEngine — the goal is to verify which kwargs
verl assembles for each precision mode, not Megatron's downstream validation.
"""

from unittest.mock import MagicMock

import pytest
import torch
from omegaconf import OmegaConf

from verl.utils.megatron import optimizer as opt_mod
from verl.utils.megatron.optimizer import init_megatron_optim_config


def _base_optim_config():
    """Build a fresh OmegaConf config with the baseline Adam optimizer fields.

    A new config object is created on every call, so each test works on its
    own copy.
    """
    baseline_fields = dict(
        optimizer="adam",
        lr=1e-3,
        min_lr=0.0,
        clip_grad=1.0,
        weight_decay=0.01,
    )
    return OmegaConf.create(baseline_fields)


@pytest.fixture
def captured_args(monkeypatch):
    """Swap ``OptimizerConfig`` for a stub that records the kwargs of its latest call.

    The fixture yields the very dict the stub writes into, so a test can
    invoke ``init_megatron_optim_config`` and then inspect the recorded kwargs.
    """
    recorded: dict = {}

    def _recording_optimizer_config(**kwargs):
        # Only the most recent call matters: wipe whatever an earlier call stored.
        recorded.clear()
        recorded.update(kwargs)
        return MagicMock(name="OptimizerConfig", **kwargs)

    # Patch the name as seen from the module under test.
    monkeypatch.setattr(opt_mod, "OptimizerConfig", _recording_optimizer_config)
    return recorded


def test_bf16_branch_enables_precision_aware_optimizer_with_bf16_state(captured_args):
    """bf16 mode enables the precision-aware optimizer with bf16 grads and Adam moments."""
    init_megatron_optim_config(_base_optim_config(), fp16=False, bf16=True)

    assert captured_args["bf16"] is True
    assert captured_args["params_dtype"] is torch.bfloat16
    assert captured_args["use_precision_aware_optimizer"] is True
    # Grad-accumulation buffer and both Adam moments follow the model dtype.
    for state_key in ("main_grads_dtype", "exp_avg_dtype", "exp_avg_sq_dtype"):
        assert captured_args[state_key] is torch.bfloat16
    # Master params dtype intentionally left at Megatron default (fp32) —
    # TE FusedAdam rejects bf16 master at init.
    assert "main_params_dtype" not in captured_args


def test_fp16_branch_uses_precision_aware_but_keeps_fp32_optimizer_state(captured_args):
    """fp16 mode sets loss-scaling kwargs but does not pass any optimizer-state dtype."""
    init_megatron_optim_config(_base_optim_config(), fp16=True, bf16=False)

    # Precision flags and parameter dtype.
    assert captured_args["fp16"] is True
    assert captured_args["bf16"] is False
    assert captured_args["params_dtype"] is torch.float16
    assert captured_args["use_precision_aware_optimizer"] is True
    # Loss-scaling settings specific to fp16.
    assert captured_args["initial_loss_scale"] == 32768
    assert captured_args["min_loss_scale"] == 1
    assert captured_args["store_param_remainders"] is False
    # Adam moment / grad dtypes left at Megatron's fp32 default in fp16 mode.
    state_dtype_keys = {"main_grads_dtype", "exp_avg_dtype", "exp_avg_sq_dtype"}
    assert state_dtype_keys.isdisjoint(captured_args)


def test_fp32_branch_disables_precision_aware_optimizer(captured_args):
    """With both flags off, fp32 params are used and no precision-aware kwargs are passed."""
    init_megatron_optim_config(_base_optim_config(), fp16=False, bf16=False)

    for precision_flag in ("fp16", "bf16"):
        assert captured_args[precision_flag] is False
    assert captured_args["params_dtype"] is torch.float32
    # Precision-aware optimizer must stay off — Megatron asserts the dtype
    # fields equal fp32 when it's disabled.
    must_be_absent = {
        "use_precision_aware_optimizer",
        "main_grads_dtype",
        "exp_avg_dtype",
        "exp_avg_sq_dtype",
    }
    assert must_be_absent.isdisjoint(captured_args)


def test_default_kwargs_dispatch_to_bf16_branch(captured_args):
    """Backward compatibility: omitting both ``fp16`` and ``bf16`` selects the bf16 path."""
    cfg = _base_optim_config()

    # No precision keywords at all: rely purely on the function's defaults.
    init_megatron_optim_config(cfg)

    assert captured_args["params_dtype"] is torch.bfloat16
    for enabled_flag in ("bf16", "use_precision_aware_optimizer"):
        assert captured_args[enabled_flag] is True


def test_fp16_wins_over_bf16_when_both_true(captured_args):
    """When both flags are set, the fp16 branch is taken and nothing from bf16 leaks in."""
    init_megatron_optim_config(_base_optim_config(), fp16=True, bf16=True)

    assert captured_args["fp16"] is True
    # The fp16 branch must also report bf16 as off, even though the caller
    # passed bf16=True; otherwise both precision flags would reach Megatron.
    assert captured_args["bf16"] is False
    assert captured_args["params_dtype"] is torch.float16
    # bf16-branch-only fields must not appear when fp16 is selected.
    assert "main_grads_dtype" not in captured_args
    assert "exp_avg_dtype" not in captured_args
    assert "exp_avg_sq_dtype" not in captured_args


def test_use_distributed_optimizer_passes_through(captured_args):
    """``use_distributed_optimizer`` is forwarded unchanged for both boolean values."""
    # The recorder keeps only the latest call, so each iteration checks its own call.
    for flag in (False, True):
        init_megatron_optim_config(_base_optim_config(), use_distributed_optimizer=flag)
        assert captured_args["use_distributed_optimizer"] is flag


def test_basic_optim_config_fields_pass_through(captured_args):
    """The basic optimizer fields from the config reach ``OptimizerConfig`` unchanged."""
    # Non-default values, so a hard-coded default in the code under test would be caught.
    numeric_fields = {
        "lr": 5e-4,
        "min_lr": 1e-5,
        "clip_grad": 0.5,
        "weight_decay": 0.1,
    }
    cfg = OmegaConf.create({"optimizer": "sgd", **numeric_fields})

    init_megatron_optim_config(cfg)

    assert captured_args["optimizer"] == "sgd"
    for field_name, expected_value in numeric_fields.items():
        assert captured_args[field_name] == pytest.approx(expected_value)


def test_override_optimizer_config_overrides_branch_defaults(captured_args):
    """Keys in ``override_optimizer_config`` replace bf16 defaults; other defaults stay."""
    user_overrides = {
        "use_precision_aware_optimizer": False,
        "exp_avg_dtype": "sentinel-override",
    }
    cfg = OmegaConf.create(
        {
            "optimizer": "adam",
            "lr": 1e-3,
            "min_lr": 0.0,
            "clip_grad": 1.0,
            "weight_decay": 0.01,
            "override_optimizer_config": user_overrides,
        }
    )

    init_megatron_optim_config(cfg, bf16=True)

    # User-supplied overrides win over the bf16-branch defaults …
    assert captured_args["use_precision_aware_optimizer"] is False
    assert captured_args["exp_avg_dtype"] == "sentinel-override"
    # … but non-overridden bf16 defaults remain.
    for untouched_key in ("main_grads_dtype", "exp_avg_sq_dtype"):
        assert captured_args[untouched_key] is torch.bfloat16


def test_missing_override_config_leaves_branch_defaults_intact(captured_args):
    """A config without ``override_optimizer_config`` must not crash and keeps bf16 defaults."""
    plain_cfg = _base_optim_config()
    # Precondition: the baseline config really has no override section.
    assert "override_optimizer_config" not in plain_cfg

    init_megatron_optim_config(plain_cfg, bf16=True)

    # The bf16-branch defaults arrive untouched.
    assert captured_args["use_precision_aware_optimizer"] is True
    assert captured_args["exp_avg_dtype"] is torch.bfloat16
36 changes: 34 additions & 2 deletions verl/utils/megatron/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@


def init_megatron_optim_config(
optim_config: dict, use_distributed_optimizer: bool = True, fp16: bool = False
optim_config: dict,
use_distributed_optimizer: bool = True,
fp16: bool = False,
bf16: bool = True,
) -> OptimizerConfig:
optim_args = {
"optimizer": optim_config.optimizer,
Expand All @@ -44,11 +47,40 @@ def init_megatron_optim_config(
"store_param_remainders": False,
}
)
else: # bf16 mode
elif bf16:
# Match precision: keep the grad-accumulation buffer and Adam
# moments (m, v) in bf16 so optimizer-state memory tracks the
# model dtype. Master parameters stay fp32 (Megatron default
# `main_params_dtype`) because TE FusedAdam currently rejects
# bf16 master weights at init (only fp32/fp16 accepted). The
# int16 "store_param_remainders" path (Megatron default True
# in bf16 mode) already eliminates the fp32 master buffer in
# favor of bf16 working + int16 remainders, achieving the same
# ~50% master-memory reduction.
# Requires TransformerEngine's FusedAdam (already needed by
# the precision-aware optimizer path). Override any of these
# via `override_optimizer_config` to opt back into fp32.
optim_args.update(
{
"bf16": True,
"params_dtype": torch.bfloat16,
"use_precision_aware_optimizer": True,
"main_grads_dtype": torch.bfloat16,
"exp_avg_dtype": torch.bfloat16,
"exp_avg_sq_dtype": torch.bfloat16,
}
)
else:
# fp32 mode: leave grad-accumulation buffer and Adam moments at
# Megatron's default torch.float32. Do not enable the precision-aware
# optimizer — it's only beneficial when a moment/grad dtype is below
# fp32, and Megatron asserts the dtype fields equal fp32 whenever the
# precision-aware optimizer is off (optimizer_config.py:258-268).
optim_args.update(
{
"bf16": False,
"fp16": False,
"params_dtype": torch.float32,
}
)
override_config = optim_config.get("override_optimizer_config", {})
Expand Down
8 changes: 7 additions & 1 deletion verl/utils/megatron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,16 @@ def get_model(
model = [Float16Module(config, model_module) for model_module in model]

if wrap_with_ddp:
# Derive the grad-bucket dtype from model precision so the DDP buffer
# matches the optimizer's main_grads_dtype: bf16 weights → bf16 grad
# bucket (paired with the precision-aware optimizer in
# init_megatron_optim_config); fp16/fp32 weights → fp32 grad bucket.
# User overrides via `override_ddp_config` still win.
grad_reduce_in_fp32 = not getattr(tfconfig, "bf16", False)
ddp_models = []
ddp_config_dict = {
"use_distributed_optimizer": use_distributed_optimizer,
"grad_reduce_in_fp32": True,
"grad_reduce_in_fp32": grad_reduce_in_fp32,
"overlap_grad_reduce": False,
}
if override_ddp_config is not None:
Expand Down
2 changes: 1 addition & 1 deletion verl/version/version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.8.0.dev
0.7.1
1 change: 1 addition & 0 deletions verl/workers/engine/megatron/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def _build_optimizer(self):
self.optimizer_config,
use_distributed_optimizer=self.engine_config.use_distributed_optimizer,
fp16=self.param_dtype == torch.float16,
bf16=self.param_dtype == torch.bfloat16,
)
optimizer = get_megatron_optimizer(model=self.module, config=optim_config_megatron)
register_megatron_training_hooks(self.module, optimizer)
Expand Down
2 changes: 2 additions & 0 deletions verl/workers/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def _build_model_optimizer(
optim_config,
use_distributed_optimizer=wrap_config.use_distributed_optimizer,
fp16=self.dtype == torch.float16,
bf16=self.dtype == torch.bfloat16,
)
actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron)
actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler(
Expand Down Expand Up @@ -1134,6 +1135,7 @@ def _build_critic_model_optimizer(
optim_config,
use_distributed_optimizer=wrap_config.use_distributed_optimizer,
fp16=self.dtype == torch.float16,
bf16=self.dtype == torch.bfloat16,
)
critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron)
critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler(
Expand Down