fix(utils): ignore masked invalid normalization values #1347

Open
haoyang9804 wants to merge 9 commits into areal-project:main from haoyang9804:fix/norm-mask-invalid-values
@haoyang9804
Contributor

Summary

AReaL's Normalization path for reward_norm and adv_norm can let an invalid value from a masked-out token corrupt every valid normalized reward or advantage. The trigger is quiet: Normalization._compute_mean() and _compute_std() multiply tensors by loss_mask, but NaN * 0 and Inf * 0 remain non-finite. A single invalid value where loss_mask=0 can therefore make valid normalized advantages non-finite and propagate into actor loss, gradients, and the optimizer update without a local exception.

This patch masks by selection before arithmetic in Normalization: masked entries are replaced with zero via torch.where(...) before mean, std, and final centering reductions. Valid-token behavior is unchanged, while invalid masked values no longer enter reward or advantage normalization.

Concrete Triggering Example

Minimal trigger data:

loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
advantages = torch.tensor([[1.0, float("nan")], [3.0, 5.0]])

The invalid slot is row 0, column 1. It is masked out with loss_mask=0, so it should not affect the valid normalization set [1.0, 3.0, 5.0].

Wrong unpatched values:

valid_normalized_advantages = [nan, nan, nan]
actor_loss = nan
grad_norm = nan
update_delta_norm = nan

Fixed values:

valid_normalized_advantages = [-1.2247374057769775, 0.0, 1.2247374057769775]
actor_loss = 0.0
grad_norm = 0.5773467421531677
update_delta_norm = 0.05773467570543289
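
The fixed values can be cross-checked by hand. The sketch below is plain Python, not AReaL code, and rests on two assumptions: normalization has the form (x - mean) / (std + eps) with the biased standard deviation, and with all logprobs at zero the importance ratio is 1, so the actor loss reduces to the negated mean of the valid advantages. Under those assumptions it reproduces the reported numbers to float32 precision.

```python
import math

# Valid advantages: the entries where loss_mask == 1.
valid = [1.0, 3.0, 5.0]
eps = 1e-5

mean = sum(valid) / len(valid)                          # 3.0
var = sum((v - mean) ** 2 for v in valid) / len(valid)  # biased variance (std_unbiased=False)
std = math.sqrt(var)

# Assumed normalization form: (x - mean) / (std + eps).
normalized = [(v - mean) / (std + eps) for v in valid]  # approx. [-1.2247374, 0.0, 1.2247374]

# Assumption: ratio == 1, so loss = -mean(advantages) over valid tokens
# and d(loss)/d(logprob_i) = -advantage_i / n_valid.
actor_loss = sum(-a for a in normalized) / len(normalized)
grad = [-a / len(normalized) for a in normalized]
grad_norm = math.sqrt(sum(g * g for g in grad))         # approx. 0.5773467
update_delta_norm = 0.1 * grad_norm                     # runner uses update = -0.1 * grad

print(normalized, actor_loss, grad_norm, update_delta_norm)
```

The loss is zero because the normalized advantages are symmetric around the mean, while the gradient is not, which is why a zero loss still yields a non-zero update.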

Reproduction Recipe

{
  "kind": "rl_sentinel_validation_recipe",
  "schema_version": 1,
  "bug_id": "AREAL-NORM-MASKED-INVALID-VALUES",
  "target": "areal",
  "validation_mode": "real_target_boundary_hook",
  "target_boundaries": [
    "areal.utils.data.Normalization._compute_mean",
    "areal.utils.data.Normalization._compute_std",
    "areal.utils.functional.ppo_actor_loss_fn"
  ],
  "requires_model_download": false,
  "requires_dataset_download": false,
  "constructed_scenario": {
    "normalization_config": {
      "mean_level": "batch",
      "mean_leave1out": false,
      "std_level": "batch",
      "std_unbiased": false,
      "group_size": 1,
      "eps": 1e-05
    },
    "loss_mask": [[1, 0], [1, 1]],
    "clean_advantages": [[1.0, 0.0], [3.0, 5.0]],
    "contaminated_advantages": [[1.0, "nan"], [3.0, 5.0]],
    "invalid_slot": {"row": 0, "col": 1, "loss_mask": 0}
  },
  "expected_unpatched": {
    "status": "confirmed",
    "valid_normalized_advantages_finite": false,
    "actor_loss_finite": false,
    "grad_norm_finite": false,
    "update_delta_norm_finite": false
  },
  "expected_fixed": {
    "status": "not_triggered",
    "valid_normalized_advantages": [-1.2247374057769775, 0.0, 1.2247374057769775],
    "valid_normalized_advantages_finite": true,
    "actor_loss": 0.0,
    "grad_norm": 0.5773467421531677,
    "update_delta_norm": 0.05773467570543289
  }
}

Validation Runner

Run from the AReaL repository root:

#!/usr/bin/env bash
set -euo pipefail

TARGET_REPO="${TARGET_REPO:-$(pwd)}"
OUTPUT_DIR="${OUTPUT_DIR:-./areal_norm_mask_validation}"
PYTHON="${PYTHON:-python3}"
mkdir -p "${OUTPUT_DIR}"
export TARGET_REPO OUTPUT_DIR
export PYTHONPATH="${TARGET_REPO}${PYTHONPATH:+:${PYTHONPATH}}"

"${PYTHON}" <<'PY'
import importlib
import json
import math
import os
import subprocess
import sys
from pathlib import Path
from types import SimpleNamespace

import torch

target_repo = Path(os.environ["TARGET_REPO"]).resolve()
output_dir = Path(os.environ["OUTPUT_DIR"]).resolve()

try:
    importlib.import_module("areal")
except Exception:
    pass

data_mod = importlib.import_module("areal.utils.data")
functional_mod = importlib.import_module("areal.utils.functional")
Normalization = data_mod.Normalization
ppo_actor_loss_fn = functional_mod.ppo_actor_loss_fn

for name, module in {
    "areal.utils.data": data_mod,
    "areal.utils.functional": functional_mod,
}.items():
    path = Path(module.__file__).resolve()
    # is_relative_to (Python 3.9+) avoids false matches on sibling directories sharing a prefix.
    if not path.is_relative_to(target_repo):
        raise RuntimeError(f"{name} imported from {path}, outside {target_repo}")

config = SimpleNamespace(
    mean_level="batch",
    mean_leave1out=False,
    std_level="batch",
    std_unbiased=False,
    group_size=1,
    eps=1e-5,
)
loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32)
mask_bool = loss_mask.bool()
advantages = torch.tensor([[1.0, float("nan")], [3.0, 5.0]], dtype=torch.float32)

normalized = Normalization(config)(advantages, loss_mask)

logprobs = torch.zeros_like(normalized, requires_grad=True)
loss, _ = ppo_actor_loss_fn(
    logprobs=logprobs,
    proximal_logprobs=torch.zeros_like(normalized),
    old_logprobs=torch.zeros_like(normalized),
    advantages=normalized,
    eps_clip=0.2,
    loss_mask=mask_bool,
    importance_sampling_level="token",
)
loss.backward()
grad = logprobs.grad.detach()
update = -0.1 * grad

def scalar(x):
    if torch.is_tensor(x):
        x = x.item()
    if isinstance(x, float) and (math.isnan(x) or math.isinf(x)):
        return str(x)
    return float(x)

def tensor_list(t):
    out = []
    for v in t.detach().cpu().reshape(-1).tolist():
        if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
            out.append(str(v))
        else:
            out.append(float(v))
    return out

result = {
    "target_commit": subprocess.check_output(
        ["git", "-C", str(target_repo), "rev-parse", "HEAD"], text=True
    ).strip(),
    "valid_normalized_advantages": tensor_list(normalized[mask_bool]),
    "valid_normalized_advantages_finite": bool(
        torch.isfinite(normalized[mask_bool]).all().item()
    ),
    "actor_loss": scalar(loss.detach()),
    "actor_loss_finite": bool(torch.isfinite(loss.detach()).item()),
    "grad_norm": scalar(torch.linalg.vector_norm(grad[mask_bool])),
    "grad_norm_finite": bool(torch.isfinite(grad[mask_bool]).all().item()),
    "update_delta_norm": scalar(torch.linalg.vector_norm(update[mask_bool])),
    "update_delta_norm_finite": bool(torch.isfinite(update[mask_bool]).all().item()),
}
print(json.dumps(result, indent=2, sort_keys=True))
PY

Observed Output

Unpatched output:

{
  "valid_normalized_advantages": ["nan", "nan", "nan"],
  "valid_normalized_advantages_finite": false,
  "actor_loss": "nan",
  "actor_loss_finite": false,
  "grad_norm": "nan",
  "grad_norm_finite": false,
  "update_delta_norm": "nan",
  "update_delta_norm_finite": false
}

Fixed output at commit 81d17d16928254677e03e134fbd3a797fd1916f3:

{
  "valid_normalized_advantages": [-1.2247374057769775, 0.0, 1.2247374057769775],
  "valid_normalized_advantages_finite": true,
  "actor_loss": 0.0,
  "actor_loss_finite": true,
  "grad_norm": 0.5773467421531677,
  "grad_norm_finite": true,
  "update_delta_norm": 0.05773467570543289,
  "update_delta_norm_finite": true
}

Root Cause

Normalization._compute_mean() and _compute_std() used mask multiplication:

x_masked = x * mask
x_centered = x_masked - mean * mask

The final centering step also used:

x_centered = x_centered * loss_mask

Mask multiplication does not guard against invalid values. PyTorch follows IEEE 754 floating-point arithmetic, under which NaN * 0 is NaN and Inf * 0 is also NaN. Since reward_norm and adv_norm reuse this class before the PPO actor loss, a masked invalid value can corrupt valid-token normalization and the policy update.
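
The same behavior shows up with ordinary Python floats, which follow the same IEEE 754 rules. A minimal sketch, with flat lists standing in for tensors; the helper names `masked_sum_by_multiply` and `masked_sum_by_select` are illustrative and do not exist in AReaL:

```python
import math

nan, inf = float("nan"), float("inf")

# IEEE 754: multiplying a non-finite value by zero never yields zero.
assert math.isnan(nan * 0.0)
assert math.isnan(inf * 0.0)

def masked_sum_by_multiply(values, mask):
    """Mirror of `x * mask` followed by a sum (the unpatched pattern)."""
    return sum(v * m for v, m in zip(values, mask))

def masked_sum_by_select(values, mask):
    """Select first, in the spirit of torch.where(mask.bool(), x, 0.0)."""
    return sum(v if m else 0.0 for v, m in zip(values, mask))

# Flattened trigger data from this PR: the NaN sits where loss_mask == 0.
values = [1.0, nan, 3.0, 5.0]
mask = [1.0, 0.0, 1.0, 1.0]

print(masked_sum_by_multiply(values, mask))  # nan
print(masked_sum_by_select(values, mask))    # 9.0
```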

Fix

The patch replaces mask multiplication with selection before arithmetic:

x_masked = torch.where(mask.bool(), x * mask, torch.zeros_like(x))
x_centered = torch.where(mask.bool(), (x - mean) * mask, torch.zeros_like(x))
x_centered = torch.where(loss_mask.bool(), x_centered, 0.0)

This keeps the same valid-token denominator and normalized values, while guaranteeing invalid masked slots are zero before reductions or final centered values are formed.
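
To illustrate the "same denominator, same values" claim, here is a plain-Python sketch of batch-level normalization with selection before arithmetic. It is not the patched AReaL code: `normalize_with_selection` is a hypothetical helper, flat lists stand in for tensors, the (x - mean) / (std + eps) form is an assumption, and at least one valid token is assumed.

```python
import math

def normalize_with_selection(values, mask, eps=1e-5):
    """Batch-level masked normalization on flat lists (illustrative only)."""
    # Select first: masked slots become 0.0 regardless of their content.
    safe = [v if m else 0.0 for v, m in zip(values, mask)]
    count = sum(1 for m in mask if m)  # valid-token denominator, unchanged by the fix
    mean = sum(safe) / count
    # Center only valid slots; masked slots stay exactly zero.
    centered = [(s - mean) if m else 0.0 for s, m in zip(safe, mask)]
    std = math.sqrt(sum(c * c for c in centered) / count)  # biased std
    return [c / (std + eps) for c in centered]

mask = [1, 0, 1, 1]
clean = [1.0, 0.0, 3.0, 5.0]  # "clean_advantages" from the recipe, flattened

# Any non-finite value at the masked slot gives the same result as clean data.
for bad in (float("nan"), float("inf"), float("-inf")):
    contaminated = [1.0, bad, 3.0, 5.0]
    assert normalize_with_selection(contaminated, mask) == normalize_with_selection(clean, mask)
```

Because the masked slot is replaced before any reduction, its original content never reaches the mean, the standard deviation, or the output.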

Tests And Checks

Added regression coverage:

tests/test_adv_norm_config.py::test_masked_invalid_values_do_not_poison_batch_normalization
tests/test_adv_norm_config.py::test_masked_invalid_values_do_not_poison_group_normalization

Checks run:

git diff --check HEAD^ HEAD -- areal/utils/data.py tests/test_adv_norm_config.py
python3 -m ruff check areal/utils/data.py tests/test_adv_norm_config.py
python3 -m ruff format --check areal/utils/data.py tests/test_adv_norm_config.py
pre-commit install --install-hooks
pre-commit run --files areal/utils/data.py tests/test_adv_norm_config.py

Results:

diff_check: passed
ruff_check: passed
ruff_format_check: passed
pre-commit install --install-hooks: passed
pre-commit run --files areal/utils/data.py tests/test_adv_norm_config.py: passed
fixed validation hook: not_triggered

Targeted pytest was attempted, but in this local environment it fails with an import error before the selected tests are collected:

python3 -m pytest -q tests/test_adv_norm_config.py::test_masked_invalid_values_do_not_poison_batch_normalization tests/test_adv_norm_config.py::test_masked_invalid_values_do_not_poison_group_normalization

ImportError: cannot import name 'DefaultStager' from 'torch.distributed.checkpoint.staging'

Contribution And Duplicate Checks

Contribution checks:

Read CONTRIBUTING.md from AReaL HEAD.
Used a Conventional Commit message: fix(utils): ignore masked invalid normalization values.
Ran pre-commit install --install-hooks.
Ran pre-commit run --files on the changed files.
Ran ruff check and ruff format --check on the changed files.

Duplicate checks performed:

BUG_FINDINGS.md: no exact AReaL Normalization reward_norm/adv_norm masked invalid-value entry.
myAReal local branches: fix/gspo-2d-masked-advantages and fix/ppo-adv-mask-invalid-logprobs checked.
myAReal remote branches: origin/fix/gspo-2d-masked-advantages checked.
myAReal pr_drafts: existing GSPO and PPOActor KL drafts checked.
Historical RL-Sentinel artifacts: searched for AREAL-NORM, Normalization, adv_norm, reward_norm, _compute_mean, and _compute_std.
Upstream areal-project/AReaL issues and PRs: searched by Normalization, adv_norm, reward_norm, masked NaN, _compute_mean, and _compute_std.

Related PRs Or Fixes

  • areal-project/AReaL#394 added leave-one-out mean and unbiased std support for normalization. It is related code, but it is not a masked invalid-value bugfix.
  • areal-project/AReaL#501 added GSPO support and sequence-level advantage computation. It is related to policy optimization, but not to Normalization._compute_mean() or _compute_std().
  • AREAL-GSPO-2D-MASKED-ADV-LEAK fixes sequence-level PPO/GSPO averaging in ppo_actor_loss_fn, not reward/advantage normalization.
  • AREAL-KL-ADV-MASKED-LOGPROB-NAN fixes PPOActor KL reward construction from old/ref logprobs, not reward_norm or adv_norm.

Collaborator

@fishcrap fishcrap left a comment

Good fix for a real silent-corruption bug. Two inline suggestions for robustness.

Collaborator

@fishcrap fishcrap left a comment

NaN-intermediate issue is resolved. Code is correct. Two inline suggestions for robustness.

Comment thread areal/utils/data.py Outdated
Comment thread areal/utils/data.py Outdated
x_masked = x * mask
factor = mask.sum(dim, keepdim=True)
x_centered = x_masked - mean * mask # only apply mean where mask is 1
x_safe = torch.where(mask.bool(), x, 0.0)
Contributor Author

done

Comment thread areal/utils/data.py Outdated
Contributor

Copilot AI left a comment

Pull request overview

This PR hardens Normalization so masked-out non-finite values do not contaminate masked mean/std reductions or final normalized outputs.

Changes:

  • Replaces mask multiplication with torch.where selection in normalization mean/std paths.
  • Ensures masked positions are zeroed during centering when loss_mask is provided.
  • Adds regression tests for masked NaN/Inf values in batch and group normalization.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File | Description
areal/utils/data.py | Updates Normalization masking logic to ignore invalid values outside the loss mask.
tests/test_adv_norm_config.py | Adds regression coverage for masked invalid values under batch and group normalization.

Collaborator

@fishcrap fishcrap left a comment

_compute_mean and _compute_std look good now. One remaining inconsistency in __call__.

Comment thread areal/utils/data.py
@fishcrap fishcrap added the safe-to-test Ready to run unit-tests in a PR. label May 19, 2026
@haoyang9804
Contributor Author

It seems that Google Cloud auth failed in CI/CD:

Error: google-github-actions/auth failed with: the GitHub Action workflow must specify exactly one of "workload_identity_provider" or "credentials_json"! If you are specifying input values via GitHub secrets, ensure the secret is being injected into the environment. By default, secrets are not passed to workflows triggered from forks, including Dependabot.

@fishcrap fishcrap added safe-to-test Ready to run unit-tests in a PR. and removed safe-to-test Ready to run unit-tests in a PR. labels May 19, 2026
Collaborator

@fishcrap fishcrap left a comment

The torch.where approach silently replaces NaN with 0, masking upstream issues (model divergence, -inf logprobs, etc.) that should be surfaced. Suggest logging a warning instead of hiding the problem.

@haoyang9804
Contributor Author

haoyang9804 commented May 19, 2026

I think there may be a misunderstanding. torch.where only replaces values where loss_mask is false. If an active position contains NaN/Inf, it is still selected and still propagates through normalization, so upstream divergence remains visible. Do you think we should also surface NaN/Inf at positions where loss_mask is false?
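
This point can be shown with a small plain-Python sketch. The helper `masked_mean_by_select` is hypothetical and only mimics selection with torch.where on a flat list:

```python
import math

def masked_mean_by_select(values, mask):
    """Masked mean that zeroes only the positions where the mask is 0."""
    safe = [v if m else 0.0 for v, m in zip(values, mask)]
    return sum(safe) / sum(1 for m in mask if m)

nan = float("nan")
mask = [1, 0, 1, 1]

masked_only = masked_mean_by_select([1.0, nan, 3.0, 5.0], mask)  # NaN at a masked slot
active_nan = masked_mean_by_select([1.0, 0.0, nan, 5.0], mask)   # NaN at an active slot

print(masked_only)             # 3.0
print(math.isnan(active_nan))  # True
```

Selection removes only the masked NaN. A NaN at an active position still reaches the reduction and makes the result NaN.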

@fishcrap
Collaborator

Thanks for the fix — the torch.where approach is the correct solution for the NaN * 0 = NaN issue.

However, the worry isn't that torch.where hides NaN at active positions. The concern is: if NaN/Inf only appears at masked positions, the training will look perfectly normal, but the model is already producing non-finite values somewhere upstream (e.g., logprob computation, reward, or forward pass). Today it's only in masked positions; tomorrow it could spread to active ones, and by then it may be too late to diagnose.

So I think we should keep the torch.where fix as-is (it's correct), but also add a lightweight warning at the entry of __call__ to surface the issue early. This way we fix the bug without losing visibility into potential upstream problems. What do you think?

@haoyang9804
Contributor Author

haoyang9804 commented May 20, 2026

No problem, let me add a warning. Many thanks for your careful review! I sincerely learned a lot.
UPDATE: I just pushed the warning :) Please take a look.
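
For readers following the thread, the kind of check discussed above might look roughly like the sketch below. This is not the code pushed to the PR: it uses plain Python lists instead of tensors, and the logger name and both function names are hypothetical.

```python
import logging
import math

logger = logging.getLogger("normalization_sketch")  # hypothetical logger name

def count_masked_nonfinite(values, mask):
    """Count non-finite entries that sit at masked-out (mask == 0) positions."""
    return sum(1 for v, m in zip(values, mask) if not m and not math.isfinite(v))

def warn_on_masked_nonfinite(values, mask):
    """Log one warning if masked positions hold NaN/Inf; return the count."""
    n_bad = count_masked_nonfinite(values, mask)
    if n_bad:
        logger.warning(
            "%d non-finite value(s) found at masked positions; they are "
            "ignored by normalization but may indicate an upstream issue.",
            n_bad,
        )
    return n_bad

# NaN at a masked slot triggers the warning; NaN at an active slot is not
# counted here, because it already surfaces through the normalized output.
print(warn_on_masked_nonfinite([1.0, float("nan"), 3.0, 5.0], [1, 0, 1, 1]))  # 1
print(warn_on_masked_nonfinite([1.0, 0.0, float("nan"), 5.0], [1, 0, 1, 1]))  # 0
```

On tensors, a comparable test could presumably be written as (~torch.isfinite(x) & ~mask.bool()).any(). Reading that result on the host forces a device synchronization, which is one reason to keep such a check lightweight.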

Labels

safe-to-test Ready to run unit-tests in a PR.

4 participants