[megatron] feat: align optimizer states and DDP grad bucket with model precision#6526
Open
kolehma8 wants to merge 2 commits into
Open
[megatron] feat: align optimizer states and DDP grad bucket with model precision#6526kolehma8 wants to merge 2 commits into
kolehma8 wants to merge 2 commits into
Conversation
### What does this PR do? Bump version to 0.7.1
…l precision In bf16 training the optimizer state (Adam m, v) and grad accumulation buffer were kept in fp32, costing ~3x the necessary memory. This change makes them match the model dtype: - `init_megatron_optim_config`: add a `bf16: bool = True` kwarg so the function can dispatch on the actual precision instead of inferring from a single `fp16` flag. The bf16 branch now enables `use_precision_aware_optimizer=True` and pins `main_grads_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype` to `torch.bfloat16`. `main_params_dtype` stays at fp32 because TE FusedAdam currently rejects bf16 master weights. fp16 path is unchanged; an explicit fp32 branch now exists and keeps Megatron's defaults. - `get_model` (megatron_utils): derive `grad_reduce_in_fp32 = not tfconfig.bf16` so the DDP grad bucket matches the optimizer's `main_grads_dtype`. fp16/fp32 paths keep the fp32 grad bucket. User overrides via `override_ddp_config` still win. - Update three call sites (`megatron_workers.py` actor + critic, `transformer_impl.py`) to pass `bf16=self.dtype == torch.bfloat16`. Verified end-to-end on Qwen3-30B-A3B GRPO: the optimizer config dump shows the bf16 dtypes landing, model loads, and step-0 validation completes (~51% gsm8k acc) before unrelated OOM from rollout colocation. Adds CPU-only unit tests in tests/utils/megatron/test_optimizer.py (9 tests, all passing) covering the bf16/fp16/fp32 dispatch, default-kwarg backward compatibility, override precedence, and basic field passthrough. Tests monkey-patch `OptimizerConfig` to verify which kwargs verl assembles without invoking Megatron's downstream validators.
Contributor
There was a problem hiding this comment.
Code Review
This pull request introduces precision-aware optimizer configuration dispatching for Megatron, specifically adding support for bf16 mode. It updates init_megatron_optim_config to configure precision-aware optimizer parameters (such as main_grads_dtype, exp_avg_dtype, and exp_avg_sq_dtype) when bf16 is enabled, and adds an explicit fallback for fp32 mode. Additionally, it derives the DDP gradient reduction precision (grad_reduce_in_fp32) based on the model's bf16 configuration, updates the worker implementations to pass the bf16 flag, and includes a comprehensive suite of unit tests to verify these configurations. There are no review comments to assess, and I have no additional feedback to provide.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
In bf16 training, the Megatron distributed-optimizer kept the Adam moments (m, v) and the grad-accumulation buffer in fp32 — 3× the necessary memory for those buffers. This PR makes them follow the model dtype, mirroring what Megatron's own CLI does with
--use-precision-aware-optimizer --main-grads-dtype bf16 --exp-avg-dtype bf16 --exp-avg-sq-dtype bf16. The DDP grad bucket dtype is derived from the same signal so it stays consistent with the optimizer's expected grad dtype.Checklist Before Starting
[{modules}] {type}: {description}—[megatron] feat: ...Test
CPU-only unit tests added at
tests/utils/megatron/test_optimizer.py(9 tests, ~9s, no GPU/TE required). They monkey-patchOptimizerConfigwith a recorder so they verify which kwargs verl assembles for each precision mode without invoking Megatron's downstream validators or TE FusedAdam. Coverage:main_grads_dtype/exp_avg_dtype/exp_avg_sq_dtypeto bf16;main_params_dtypeintentionally left unset.store_param_remainders=False).bf16=False, fp16=False) keeps Megatron defaults and does not enable precision-aware (Megatron asserts dtype fields must equal fp32 otherwise).fp16=True, bf16=True→ fp16 wins (if/elif precedence).use_distributed_optimizerpasses through.override_optimizer_configwins over branch defaults; non-overridden fields preserved.override_optimizer_configdoesn't crash.Manual end-to-end validation on Qwen3-30B-A3B GRPO (gsm8k, 8×H100, EP=8, mbs=1):
API and Usage Example
init_megatron_optim_configgains abf16: bool = Truekwarg so callers can dispatch on the actual precision instead of inferring from a singlefp16flag. The new kwarg defaults to True, preserving previous behavior for any caller that omits it (the old unlabeledelsebranch was effectively bf16).Design & Code Changes
Behavior matrix:
Files changed:
verl/utils/megatron/optimizer.py—init_megatron_optim_configrestructured into three branches; newbf16: bool = Truekwarg.verl/utils/megatron_utils.py— Inget_model, derivegrad_reduce_in_fp32 = not getattr(tfconfig, "bf16", False)so the DDP grad bucket dtype matchesmain_grads_dtype. User overrides viaoverride_ddp_configstill win.verl/workers/megatron_workers.py— actor + critic call sites passbf16=self.dtype == torch.bfloat16.verl/workers/engine/megatron/transformer_impl.py— same.tests/utils/megatron/test_optimizer.py— new test file (9 tests).Backward compatibility & known limits:
bf16kwarg defaults toTrue, so existing callers that pass onlyfp16=...keep bf16 dispatch — matches the prior implicit-bf16elsebranch behavior. Only new callers wanting the explicit fp32 path needbf16=False.main_params_dtypeis not dropped to bf16. TE FusedAdam currently rejects bf16 master weights at init (RuntimeError: FusedAdam only supports fp32/fp16 master weights.). NVIDIA's nightly docs describe future support; once TE catches up, we can revisit.Checklist Before Submitting
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaystests/utils/megatron/test_optimizer.py, 9 CPU-only tests.ci-requestchannel in the verl Slack.