[megatron] feat: align optimizer states and DDP grad bucket with model precision #6526

Open: kolehma8 wants to merge 2 commits into verl-project:main from kolehma8:kolehma8/megatron

@kolehma8 kolehma8 commented May 28, 2026

What does this PR do?

In bf16 training, the Megatron distributed optimizer kept the Adam moments (m, v) and the grad-accumulation buffer in fp32, which is 2× the memory those buffers need (4 bytes per element instead of 2). This PR makes them follow the model dtype, mirroring what Megatron's own CLI does with --use-precision-aware-optimizer --main-grads-dtype bf16 --exp-avg-dtype bf16 --exp-avg-sq-dtype bf16. The DDP grad bucket dtype is derived from the same signal so it stays consistent with the optimizer's expected grad dtype.
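
The saving can be estimated with simple arithmetic. The sketch below is illustrative only: it assumes 4 bytes per fp32 element and 2 bytes per bf16 element, counts just the three buffers named above, and uses a hypothetical parameter count. The names `BUFFERS` and `buffer_bytes` are not part of verl or Megatron.

```python
# Bytes per element for the two dtypes discussed in this PR.
BYTES_PER_ELEMENT = {"fp32": 4, "bf16": 2}

# The three per-parameter buffers affected: Adam m, Adam v, grad accumulation.
BUFFERS = ("exp_avg", "exp_avg_sq", "main_grad")


def buffer_bytes(num_params: int, dtype: str) -> int:
    """Total bytes for the three buffers when each is stored in `dtype`."""
    return num_params * BYTES_PER_ELEMENT[dtype] * len(BUFFERS)


num_params = 1_000_000_000  # hypothetical element count, chosen for round numbers
fp32_bytes = buffer_bytes(num_params, "fp32")
bf16_bytes = buffer_bytes(num_params, "bf16")

print(f"fp32 buffers: {fp32_bytes / 1e9:.1f} GB")
print(f"bf16 buffers: {bf16_bytes / 1e9:.1f} GB")
```

The fp32 master weights are left out of the count because this PR does not change them.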

Checklist Before Starting

Test

CPU-only unit tests added at tests/utils/megatron/test_optimizer.py (9 tests, ~9s, no GPU/TE required). They monkey-patch OptimizerConfig with a recorder so they verify which kwargs verl assembles for each precision mode without invoking Megatron's downstream validators or TE FusedAdam. Coverage:

  • bf16 branch enables precision-aware optimizer and pins main_grads_dtype / exp_avg_dtype / exp_avg_sq_dtype to bf16; main_params_dtype intentionally left unset.
  • fp16 branch enables the precision-aware optimizer but leaves Adam moment dtypes at fp32 (with loss scaler + store_param_remainders=False).
  • fp32 branch (both bf16=False, fp16=False) keeps Megatron defaults and does not enable precision-aware (Megatron asserts dtype fields must equal fp32 otherwise).
  • Default kwargs dispatch to bf16 — backward compat for existing callers.
  • fp16=True, bf16=True → fp16 wins (if/elif precedence).
  • use_distributed_optimizer passes through.
  • Top-level config fields (optimizer/lr/min_lr/clip_grad/weight_decay) pass through verbatim.
  • override_optimizer_config wins over branch defaults; non-overridden fields preserved.
  • Missing override_optimizer_config doesn't crash.
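
The recorder approach described above can be sketched as follows. This is a minimal, self-contained illustration of the technique, not the test file from this PR: `HeavyOptimizerConfig`, `RecordingConfig`, `optimizer_module`, and `build_config` are hypothetical stand-ins, and the toy builder assembles far fewer fields than the real function.

```python
from types import SimpleNamespace
from unittest import mock


class HeavyOptimizerConfig:
    """Stand-in for a config class whose constructor the test wants to avoid."""

    def __init__(self, **kwargs):
        raise RuntimeError("pretend this needs a GPU and downstream validation")


class RecordingConfig:
    """Recorder: stores the kwargs it was built with and does nothing else."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


# Hypothetical module-like namespace in which the builder looks up the class.
optimizer_module = SimpleNamespace(OptimizerConfig=HeavyOptimizerConfig)


def build_config(lr: float, bf16: bool = True):
    """Toy builder: assembles kwargs, then hands them to OptimizerConfig."""
    kwargs = {"lr": lr}
    if bf16:
        kwargs["use_precision_aware_optimizer"] = True
    return optimizer_module.OptimizerConfig(**kwargs)


def recorded_kwargs(**builder_args):
    """Run the builder with the recorder patched in and return what it saw."""
    with mock.patch.object(optimizer_module, "OptimizerConfig", RecordingConfig):
        return build_config(**builder_args).kwargs


print(recorded_kwargs(lr=1e-6))
```

Because the patch is scoped to the `with` block, the original class is restored afterwards, so each test only observes the kwargs and never triggers the real constructor.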

Manual end-to-end validation on Qwen3-30B-A3B GRPO (gsm8k, 8×H100, EP=8, mbs=1):

  • Optimizer config dump in worker logs shows the new fields landing:
optimizer config after override: {..., 'use_precision_aware_optimizer': True,
 'main_grads_dtype': torch.bfloat16, 'exp_avg_dtype': torch.bfloat16,
 'exp_avg_sq_dtype': torch.bfloat16}
  • Megatron model loads cleanly, step-0 validation completes (~51% gsm8k acc), forward + backward pass + optimizer step execute.

API and Usage Example

init_megatron_optim_config gains a bf16: bool = True kwarg so callers can dispatch on the actual precision instead of inferring from a single fp16 flag. The new kwarg defaults to True, preserving previous behavior for any caller that omits it (the old unlabeled else branch was effectively bf16).

# Existing call sites updated:
optim_config_megatron = init_megatron_optim_config(
    optim_config,
    use_distributed_optimizer=wrap_config.use_distributed_optimizer,
    fp16=self.dtype == torch.float16,
    bf16=self.dtype == torch.bfloat16,   # NEW
)

# Opt back into the old fp32 optimizer-state layout via override:
# +actor_rollout_ref.actor.megatron.override_optimizer_config.use_precision_aware_optimizer=False
# +actor_rollout_ref.actor.megatron.override_ddp_config.grad_reduce_in_fp32=True

Design & Code Changes

Behavior matrix:

| Mode | param dtype | DDP grad bucket | Adam m, v | Master weights | precision-aware opt |
|---|---|---|---|---|---|
| bf16 (new default) | bf16 | bf16 | bf16 | fp32 (TE limit) | True |
| fp16 | fp16 | fp32 | fp32 | fp32 | True |
| fp32 (new explicit branch) | fp32 | fp32 | fp32 | fp32 | False |
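
The optimizer columns of the behavior matrix above, together with the precedence rules from the test list, could be expressed roughly as below. This is a sketch under stated assumptions rather than the code in this PR: `sketch_optim_kwargs` is a hypothetical name, dtypes are plain strings instead of torch dtypes so the block runs without torch, and the fp16 loss-scaler settings are omitted because the description does not spell them out.

```python
def sketch_optim_kwargs(fp16: bool = False, bf16: bool = True, override=None) -> dict:
    """Assemble the precision-related optimizer kwargs for one mode."""
    kwargs = {}
    if fp16:
        # fp16 is checked first, so fp16=True, bf16=True resolves to fp16.
        kwargs["use_precision_aware_optimizer"] = True
        kwargs["store_param_remainders"] = False
        # Loss-scaler fields are omitted in this sketch.
    elif bf16:
        kwargs["use_precision_aware_optimizer"] = True
        kwargs["main_grads_dtype"] = "bf16"
        kwargs["exp_avg_dtype"] = "bf16"
        kwargs["exp_avg_sq_dtype"] = "bf16"
        # main_params_dtype is deliberately not set (master weights stay fp32).
    # fp32: nothing is added, so the library defaults apply.

    # User overrides are applied last and therefore win over branch defaults.
    kwargs.update(override or {})
    return kwargs


for mode in ({"bf16": True}, {"fp16": True}, {"fp16": False, "bf16": False}):
    print(mode, "->", sketch_optim_kwargs(**mode))
```

Applying the override dict last is what lets a user restore the old layout with a single `use_precision_aware_optimizer=False` entry while the untouched fields keep their branch defaults.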

Files changed:

  • verl/utils/megatron/optimizer.py: init_megatron_optim_config restructured into three branches; new bf16: bool = True kwarg.
  • verl/utils/megatron_utils.py — In get_model, derive grad_reduce_in_fp32 = not getattr(tfconfig, "bf16", False) so the DDP grad bucket dtype matches main_grads_dtype. User overrides via override_ddp_config still win.
  • verl/workers/megatron_workers.py — actor + critic call sites pass bf16=self.dtype == torch.bfloat16.
  • verl/workers/engine/megatron/transformer_impl.py — same.
  • tests/utils/megatron/test_optimizer.py — new test file (9 tests).
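
The grad-bucket rule from the megatron_utils.py change can be sketched in isolation. The expression for `grad_reduce_in_fp32` is taken from the description above; everything else (`sketch_ddp_kwargs`, the use of `SimpleNamespace` as a stand-in for the transformer config, and the plain-dict override) is a hypothetical simplification for illustration.

```python
from types import SimpleNamespace


def sketch_ddp_kwargs(tfconfig, override_ddp_config=None) -> dict:
    """Derive the DDP grad-reduce precision from the model config."""
    ddp_kwargs = {
        # bf16 model -> bf16 grad bucket; anything else keeps the fp32 bucket.
        "grad_reduce_in_fp32": not getattr(tfconfig, "bf16", False),
    }
    # User-supplied overrides are applied last, so they still win.
    ddp_kwargs.update(override_ddp_config or {})
    return ddp_kwargs


bf16_model = SimpleNamespace(bf16=True)
fp16_model = SimpleNamespace(fp16=True)  # no bf16 attribute at all

print(sketch_ddp_kwargs(bf16_model))
print(sketch_ddp_kwargs(fp16_model))
print(sketch_ddp_kwargs(bf16_model, {"grad_reduce_in_fp32": True}))
```

Using `getattr` with a `False` default means a config object that lacks a `bf16` attribute falls back to the fp32 bucket instead of raising.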

Backward compatibility & known limits:

  • The new bf16 kwarg defaults to True, so existing callers that pass only fp16=... keep bf16 dispatch — matches the prior implicit-bf16 else branch behavior. Only new callers wanting the explicit fp32 path need bf16=False.
  • Existing bf16 runs will see lower optimizer-state memory but slightly different numerics: the precision-aware optimizer uses stochastic-rounding tricks to recover precision, but updates are not bit-equivalent to the prior fp32-state path. Worth a CHANGELOG entry.
  • main_params_dtype is not dropped to bf16. TE FusedAdam currently rejects bf16 master weights at init (RuntimeError: FusedAdam only supports fp32/fp16 master weights.). NVIDIA's nightly docs describe future support; once TE catches up, we can revisit.
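
To make the numerics caveat above concrete, the following sketch shows how a small relative update can vanish when a value is stored in bf16. It is an illustration with plain round-to-nearest-even implemented in pure Python and does not model the rounding that Megatron or TE actually apply; `round_to_bf16` is a hypothetical helper, and `beta2 = 0.999` is assumed as a typical Adam setting, not read from this PR.

```python
import struct


def round_to_bf16(x: float) -> float:
    """Round a finite float to the nearest bf16 value (ties to even)."""
    # Reinterpret the value as float32 bits; bf16 keeps the upper 16 bits.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so that truncating the low 16 bits rounds to nearest even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


beta2 = 0.999  # assumed decay rate for the second moment
v_start = 1.0

# One decay step of the second moment with a zero gradient: v <- beta2 * v.
v_wide = v_start * beta2                 # kept in a wider format
v_bf16 = round_to_bf16(v_start * beta2)  # stored back into bf16

print("wide:", v_wide)
print("bf16:", v_bf16)
```

With round-to-nearest, the bf16 result lands back on the starting value because the nearest representable numbers around 0.999 are 0.99609375 and 1.0. This is one reason bf16 optimizer state cannot be bit-equivalent to the fp32-state path.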

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation.
  • Add unit or end-to-end test(s) — tests/utils/megatron/test_optimizer.py, 9 CPU-only tests.
  • Once your PR is ready for CI, send a message in the ci-request channel in the verl Slack.

wuxibin89 and others added 2 commits March 16, 2026 10:41
### What does this PR do?

Bump version to 0.7.1
…l precision

In bf16 training the optimizer state (Adam m, v) and grad accumulation buffer were kept in fp32, costing ~2x the necessary memory (4 bytes per element instead of 2). This change makes them match the model dtype:

- `init_megatron_optim_config`: add a `bf16: bool = True` kwarg so the function can dispatch on the actual precision instead of inferring from a single `fp16` flag. The bf16 branch now enables `use_precision_aware_optimizer=True` and pins `main_grads_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype` to `torch.bfloat16`. `main_params_dtype` stays at fp32 because TE FusedAdam currently rejects bf16 master weights. fp16 path is unchanged; an explicit fp32 branch now exists and keeps Megatron's defaults.
- `get_model` (megatron_utils): derive `grad_reduce_in_fp32 = not tfconfig.bf16` so the DDP grad bucket matches the optimizer's `main_grads_dtype`. fp16/fp32 paths keep the fp32 grad bucket. User overrides via `override_ddp_config` still win.
- Update three call sites (`megatron_workers.py` actor + critic, `transformer_impl.py`) to pass `bf16=self.dtype == torch.bfloat16`.

Verified end-to-end on Qwen3-30B-A3B GRPO: the optimizer config dump shows the bf16 dtypes landing, model loads, and step-0 validation completes (~51% gsm8k acc) before unrelated OOM from rollout colocation.

Adds CPU-only unit tests in tests/utils/megatron/test_optimizer.py (9 tests, all passing) covering the bf16/fp16/fp32 dispatch, default-kwarg backward compatibility, override precedence, and basic field passthrough. Tests monkey-patch `OptimizerConfig` to verify which kwargs verl assembles without invoking Megatron's downstream validators.
CLAassistant commented May 28, 2026

CLA assistant check
All committers have signed the CLA.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces precision-aware optimizer configuration dispatching for Megatron, specifically adding support for bf16 mode. It updates init_megatron_optim_config to configure precision-aware optimizer parameters (such as main_grads_dtype, exp_avg_dtype, and exp_avg_sq_dtype) when bf16 is enabled, and adds an explicit fallback for fp32 mode. Additionally, it derives the DDP gradient reduction precision (grad_reduce_in_fp32) based on the model's bf16 configuration, updates the worker implementations to pass the bf16 flag, and includes a comprehensive suite of unit tests to verify these configurations. There are no review comments to assess, and I have no additional feedback to provide.
