Skip to content
43 changes: 37 additions & 6 deletions areal/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,6 +1404,38 @@ def __call__(
bs = x.size(0)
eps = self.eps

non_finite = ~torch.isfinite(x)
if non_finite.any().item():
if loss_mask is None:
logger.warning(
"Normalization input contains non-finite values and no "
"loss_mask was provided. They will propagate through "
"normalization and may indicate an upstream numerical issue."
)
else:
active_non_finite = loss_mask.bool().logical_and(non_finite)
masked_non_finite = (~loss_mask.bool()).logical_and(non_finite)
if active_non_finite.any().item() and masked_non_finite.any().item():
logger.warning(
"Normalization input contains non-finite values at both "
"active and masked positions. Active non-finite values "
"will propagate through normalization; masked non-finite "
"values will be ignored by loss_mask. This may indicate "
"an upstream numerical issue."
)
elif active_non_finite.any().item():
logger.warning(
"Normalization input contains non-finite values at active "
"positions. They will propagate through normalization and "
"may indicate an upstream numerical issue."
)
else:
logger.warning(
"Normalization input contains non-finite values at masked "
"positions. They will be ignored by loss_mask, but this "
"may indicate an upstream numerical issue."
)

# Early return if no elements are active (all masked out)
if loss_mask is not None and loss_mask.sum().item() == 0:
return x.float()
Expand Down Expand Up @@ -1446,10 +1478,10 @@ def __call__(
mean = torch.zeros_like(x)

# Subtract mean
x_centered = x - mean
# mask unrelevant elements as 0
if loss_mask is not None:
x_centered = x_centered * loss_mask
x_centered = torch.where(loss_mask.bool(), x - mean, 0.0)
else:
x_centered = x - mean
Comment thread
fishcrap marked this conversation as resolved.

# Step 2: Compute std
if self.std_level == "batch":
Expand Down Expand Up @@ -1517,7 +1549,7 @@ def _compute_mean(
x_sum = x.sum(dim=dim, keepdim=True)
else:
mask = mask.to(dtype)
x_masked = x * mask
x_masked = torch.where(mask.bool(), x, 0.0)
factor = mask.sum(dim, keepdim=True)
x_sum = x_masked.sum(dim=dim, keepdim=True)

Expand Down Expand Up @@ -1576,9 +1608,8 @@ def _compute_std(
x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True)
else:
mask = mask.to(dtype)
x_masked = x * mask
factor = mask.sum(dim, keepdim=True)
x_centered = x_masked - mean * mask # only apply mean where mask is 1
x_centered = torch.where(mask.bool(), x - mean, 0.0)
x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True)

if dist.is_initialized() and all_reduce:
Expand Down
93 changes: 93 additions & 0 deletions tests/test_adv_norm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,99 @@ def test_non_trivial_loss_mask_batch_normalization():
assert torch.abs(non_masked_values.std() - 1.0) < 1e-5


def test_masked_invalid_values_do_not_poison_batch_normalization():
config = NormConfig(mean_level="batch", std_level="batch", group_size=1)
adv_norm = Normalization(config)

advantages = torch.tensor(
[[1.0, float("nan")], [3.0, 5.0]],
dtype=torch.float32,
)
loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32)

normalized = adv_norm(advantages, loss_mask)

expected_valid = torch.tensor(
[-1.0, 0.0, 1.0],
dtype=torch.float32,
)
assert torch.isfinite(normalized).all()
assert torch.allclose(normalized[loss_mask.bool()], expected_valid, atol=1e-6)
assert normalized[0, 1].item() == 0.0


def test_masked_invalid_values_emit_warning():
config = NormConfig(mean_level="batch", std_level="batch", group_size=1)
adv_norm = Normalization(config)

advantages = torch.tensor(
[[1.0, float("nan")], [3.0, 5.0]],
dtype=torch.float32,
)
loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32)

with patch("areal.utils.data.logger.warning") as warning:
normalized = adv_norm(advantages, loss_mask)

warning.assert_called_once()
assert "non-finite values at masked positions" in warning.call_args.args[0]
assert torch.isfinite(normalized).all()


def test_masked_invalid_values_do_not_poison_group_normalization():
config = NormConfig(mean_level="group", std_level="group", group_size=2)
adv_norm = Normalization(config)

advantages = torch.tensor(
[[1.0, float("inf")], [3.0, 5.0]],
dtype=torch.float32,
)
loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32)

normalized = adv_norm(advantages, loss_mask)

expected_valid = torch.tensor(
[-1.0, 0.0, 1.0],
dtype=torch.float32,
)
assert torch.isfinite(normalized).all()
assert torch.allclose(normalized[loss_mask.bool()], expected_valid, atol=1e-6)
assert normalized[0, 1].item() == 0.0


def test_unmasked_invalid_values_are_not_sanitized():
config = NormConfig(mean_level="batch", std_level=None, group_size=1)
adv_norm = Normalization(config)

advantages = torch.tensor(
[[1.0, float("nan")], [3.0, 5.0]],
dtype=torch.float32,
)
loss_mask = torch.ones_like(advantages)

normalized = adv_norm(advantages, loss_mask)

assert torch.isnan(normalized).any()


def test_unmasked_invalid_values_emit_warning_and_are_not_sanitized():
config = NormConfig(mean_level="batch", std_level=None, group_size=1)
adv_norm = Normalization(config)

advantages = torch.tensor(
[[1.0, float("nan")], [3.0, 5.0]],
dtype=torch.float32,
)
loss_mask = torch.ones_like(advantages)

with patch("areal.utils.data.logger.warning") as warning:
normalized = adv_norm(advantages, loss_mask)

warning.assert_called_once()
assert "non-finite values at active positions" in warning.call_args.args[0]
assert torch.isnan(normalized).any()


def test_non_trivial_loss_mask_leave_one_out():
"""Test leave-one-out normalization with non-trivial loss mask and verify expected values."""
config = NormConfig(
Expand Down
Loading