fix(utils): ignore masked invalid normalization values #1347
haoyang9804 wants to merge 9 commits into
fishcrap
left a comment
Good fix for a real silent-corruption bug. Two inline suggestions for robustness.
    x_masked = x * mask
    factor = mask.sum(dim, keepdim=True)
    x_centered = x_masked - mean * mask  # only apply mean where mask is 1
    x_safe = torch.where(mask.bool(), x, 0.0)
Pull request overview
This PR hardens Normalization so masked-out non-finite values do not contaminate masked mean/std reductions or final normalized outputs.
Changes:
- Replaces mask multiplication with `torch.where` selection in normalization mean/std paths.
- Ensures masked positions are zeroed during centering when `loss_mask` is provided.
- Adds regression tests for masked `NaN`/`Inf` values in batch and group normalization.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `areal/utils/data.py` | Updates Normalization masking logic to ignore invalid values outside the loss mask. |
| `tests/test_adv_norm_config.py` | Adds regression coverage for masked invalid values under batch and group normalization. |
fishcrap
left a comment
_compute_mean and _compute_std look good now. One remaining inconsistency in __call__.
It seems that Google Cloud auth failed in CI/CD:
fishcrap
left a comment
The torch.where approach silently replaces NaN with 0, masking upstream issues (model divergence, -inf logprobs, etc.) that should be surfaced. Suggest logging a warning instead of hiding the problem.
I guess there is a gap.
Thanks for the fix — the However, the worry isn't that So I think we should keep the |
No problem, let me add a warning. And many thanks for your careful review! I sincerely learned a lot.
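One way the warning discussed in this thread could look is sketched below in plain Python. It is only an illustration: the helper name `select_valid`, the logger name, and the message text are hypothetical and are not taken from the PR, and plain lists stand in for the tensors that `Normalization` actually operates on.

```python
import logging
import math

# Hypothetical logger name, used only for this sketch.
logger = logging.getLogger("normalization_sketch")


def select_valid(values, mask):
    """Zero out masked positions by selection and warn about hidden non-finite values.

    Returns the cleaned list and the number of non-finite values that were
    found at masked-out positions.
    """
    # Count NaN/Inf entries that sit where the mask is 0.
    hidden = sum(
        1 for v, m in zip(values, mask) if not m and not math.isfinite(v)
    )
    if hidden:
        # Surface the upstream problem instead of silently replacing it.
        logger.warning(
            "%d non-finite value(s) found at masked-out positions", hidden
        )
    # Selection, not multiplication: masked entries become exactly 0.0.
    cleaned = [v if m else 0.0 for v, m in zip(values, mask)]
    return cleaned, hidden


cleaned, hidden = select_valid([1.0, float("nan"), 3.0, float("-inf")], [1, 0, 1, 0])
```

With this shape, the reductions still see only finite inputs, while the count of hidden invalid values remains visible in the logs for debugging model divergence or `-inf` logprobs.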
Summary
AReaL's `Normalization` path for `reward_norm` and `adv_norm` can let an invalid value from a masked-out token corrupt every valid normalized reward or advantage. The trigger is quiet: `Normalization._compute_mean()` and `_compute_std()` multiply tensors by `loss_mask`, but `NaN * 0` and `Inf * 0` remain non-finite. A single invalid value where `loss_mask=0` can therefore make valid normalized advantages non-finite and propagate into actor loss, gradients, and the optimizer update without a local exception.

This patch masks by selection before arithmetic in `Normalization`: masked entries are replaced with zero via `torch.where(...)` before mean, std, and final centering reductions. Valid-token behavior is unchanged, while invalid masked values no longer enter reward or advantage normalization.

Concrete Triggering Example
Minimal trigger data:
The invalid slot is row 0, column 1. It is masked out with `loss_mask=0`, so it should not affect the valid normalization set `[1.0, 3.0, 5.0]`.

Wrong unpatched values:
Fixed values:
Reproduction Recipe

    {
      "kind": "rl_sentinel_validation_recipe",
      "schema_version": 1,
      "bug_id": "AREAL-NORM-MASKED-INVALID-VALUES",
      "target": "areal",
      "validation_mode": "real_target_boundary_hook",
      "target_boundaries": [
        "areal.utils.data.Normalization._compute_mean",
        "areal.utils.data.Normalization._compute_std",
        "areal.utils.functional.ppo_actor_loss_fn"
      ],
      "requires_model_download": false,
      "requires_dataset_download": false,
      "constructed_scenario": {
        "normalization_config": {
          "mean_level": "batch",
          "mean_leave1out": false,
          "std_level": "batch",
          "std_unbiased": false,
          "group_size": 1,
          "eps": 1e-05
        },
        "loss_mask": [[1, 0], [1, 1]],
        "clean_advantages": [[1.0, 0.0], [3.0, 5.0]],
        "contaminated_advantages": [[1.0, "nan"], [3.0, 5.0]],
        "invalid_slot": {"row": 0, "col": 1, "loss_mask": 0}
      },
      "expected_unpatched": {
        "status": "confirmed",
        "valid_normalized_advantages_finite": false,
        "actor_loss_finite": false,
        "grad_norm_finite": false,
        "update_delta_norm_finite": false
      },
      "expected_fixed": {
        "status": "not_triggered",
        "valid_normalized_advantages": [-1.2247374057769775, 0.0, 1.2247374057769775],
        "valid_normalized_advantages_finite": true,
        "actor_loss": 0.0,
        "grad_norm": 0.5773467421531677,
        "update_delta_norm": 0.05773467570543289
      }
    }

Validation Runner
Run from the AReaL repository root:
Observed Output
Unpatched output:

    {
      "valid_normalized_advantages": ["nan", "nan", "nan"],
      "valid_normalized_advantages_finite": false,
      "actor_loss": "nan",
      "actor_loss_finite": false,
      "grad_norm": "nan",
      "grad_norm_finite": false,
      "update_delta_norm": "nan",
      "update_delta_norm_finite": false
    }

Fixed output at commit `81d17d16928254677e03e134fbd3a797fd1916f3`:

    {
      "valid_normalized_advantages": [-1.2247374057769775, 0.0, 1.2247374057769775],
      "valid_normalized_advantages_finite": true,
      "actor_loss": 0.0,
      "actor_loss_finite": true,
      "grad_norm": 0.5773467421531677,
      "grad_norm_finite": true,
      "update_delta_norm": 0.05773467570543289,
      "update_delta_norm_finite": true
    }

Root Cause
`Normalization._compute_mean()` and `_compute_std()` used mask multiplication:

The final centering step also used:
Mask multiplication is not a valid invalid-value guard. Under IEEE 754 floating-point arithmetic, which PyTorch follows, `NaN * 0` remains `NaN`, and `Inf * 0` also evaluates to `NaN`. Since `reward_norm` and `adv_norm` reuse this class before PPO actor loss, a masked invalid value can corrupt valid-token normalization and the policy update.

Fix
The patch replaces mask multiplication with selection before arithmetic:
This keeps the same valid-token denominator and normalized values, while guaranteeing invalid masked slots are zero before reductions or final centered values are formed.
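The contrast between masking by multiplication and masking by selection can be sketched in plain Python, using the trigger data from the reproduction recipe flattened to one row. This is an illustration under stated assumptions, not the AReaL implementation: `masked_normalize` is a hypothetical helper, plain lists stand in for tensors, and the formula `(x - mean) / (std + eps)` with a biased std is assumed from the recipe's `std_unbiased: false` and `eps: 1e-05` settings.

```python
import math


def masked_normalize(values, mask, eps=1e-5, select=True):
    """Normalize `values` over positions where `mask` is 1 (hypothetical helper).

    select=False mimics masking by multiplication.
    select=True mimics masking by selection, the idea behind
    torch.where(mask.bool(), x, 0.0).
    """
    if select:
        kept = [v if m else 0.0 for v, m in zip(values, mask)]
    else:
        kept = [v * m for v, m in zip(values, mask)]  # nan * 0 is still nan
    count = sum(mask)  # denominator counts valid positions only
    mean = sum(kept) / count
    if select:
        # Center valid positions; masked positions stay exactly zero.
        centered = [k - mean if m else 0.0 for k, m in zip(kept, mask)]
    else:
        centered = [k - mean * m for k, m in zip(kept, mask)]
    # Biased standard deviation over the valid positions (assumed setting).
    std = math.sqrt(sum(c * c for c in centered) / count)
    return [c / (std + eps) for c in centered]


# Trigger data from the recipe, flattened: the NaN sits where the mask is 0.
values = [1.0, float("nan"), 3.0, 5.0]
mask = [1, 0, 1, 1]

by_multiplication = masked_normalize(values, mask, select=False)
by_selection = masked_normalize(values, mask, select=True)

print(by_multiplication)                     # every entry is nan
print([round(v, 4) for v in by_selection])   # [-1.2247, 0.0, 0.0, 1.2247]
```

The selected path lands within floating-point rounding of the `valid_normalized_advantages` reported for the fixed commit, while the multiplied path turns every entry into `nan`, including the three valid ones.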
Tests And Checks
Added regression coverage:
Checks run:
Results:
Targeted pytest was run but is blocked in this local environment before collecting the selected tests:
Contribution And Duplicate Checks
Contribution checks:
Duplicate checks performed:
Related PRs Or Fixes
- areal-project/AReaL#394 added leave-one-out mean and unbiased std support for normalization. It is related code, but it is not a masked invalid-value bugfix.
- areal-project/AReaL#501 added GSPO support and sequence-level advantage computation. It is related to policy optimization, but not to `Normalization._compute_mean()` or `_compute_std()`.
- `AREAL-GSPO-2D-MASKED-ADV-LEAK` fixes sequence-level PPO/GSPO averaging in `ppo_actor_loss_fn`, not reward/advantage normalization.
- `AREAL-KL-ADV-MASKED-LOGPROB-NAN` fixes PPOActor KL reward construction from old/ref logprobs, not `reward_norm` or `adv_norm`.