[algo] fix: normalize GDPO advantages over responses#6497
Conversation
There was a problem hiding this comment.
Code Review
This pull request modifies the advantage calculation in compute_gdpo_outcome_advantage to compute and whiten advantages at the response level rather than the token level. The reviewer identified a critical issue where whitening a single valid sequence (or a batch size of 1) can cause a division-by-zero crash, and provided a code suggestion to safely handle this edge case.
| response_level_advantage = verl_F.masked_mean(new_advantage, response_mask, axis=-1) | ||
| response_level_mask = response_mask.sum(dim=-1) > 0 | ||
| response_level_advantage = verl_F.masked_whiten(response_level_advantage, response_level_mask) |
There was a problem hiding this comment.
If the batch size is 1, or if only one sequence in the batch has a valid response (i.e., response_level_mask.sum() is 1), verl_F.masked_whiten will call masked_var which raises a ValueError: The sum of the mask is one, which can cause a division by zero. This will crash the training or validation loop.
To prevent this, we should check if response_level_mask.sum() > 1 before applying masked_whiten. If there is at most one valid sequence, we can safely set the whitened advantages to zero. Additionally, converting response_level_mask to the same dtype as response_level_advantage ensures compatibility across different PyTorch versions and hardware backends.
| response_level_advantage = verl_F.masked_mean(new_advantage, response_mask, axis=-1) | |
| response_level_mask = response_mask.sum(dim=-1) > 0 | |
| response_level_advantage = verl_F.masked_whiten(response_level_advantage, response_level_mask) | |
| response_level_advantage = verl_F.masked_mean(new_advantage, response_mask, axis=-1) | |
| response_level_mask = (response_mask.sum(dim=-1) > 0).to(response_level_advantage.dtype) | |
| if response_level_mask.sum() > 1: | |
| response_level_advantage = verl_F.masked_whiten(response_level_advantage, response_level_mask) | |
| else: | |
| response_level_advantage = torch.zeros_like(response_level_advantage) |
What does this PR do?
This PR updates GDPO's final batch-level advantage normalization to operate over generated responses instead of valid response tokens.
In the current GDPO implementation,
compute_grpo_outcome_advantagereturns outcome-level advantages with token shape, i.e. one scalar advantage is broadcast to all valid response tokens. The final GDPO normalization then applies:This computes whitening statistics over all valid tokens. As a result, longer responses contribute more copies of the same outcome-level advantage to the batch mean/std.
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,veomni,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,cfg,reward,fully_async,one_step_off,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)recipesubmodule, please also update the reference to the submodule commit viagit submodule update --remoteorcd recipe && git pull origin main.