diff --git a/areal/utils/data.py b/areal/utils/data.py index 09368e4ef..bf25231de 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -1404,6 +1404,38 @@ def __call__( bs = x.size(0) eps = self.eps + non_finite = ~torch.isfinite(x) + if non_finite.any().item(): + if loss_mask is None: + logger.warning( + "Normalization input contains non-finite values and no " + "loss_mask was provided. They will propagate through " + "normalization and may indicate an upstream numerical issue." + ) + else: + active_non_finite = loss_mask.bool().logical_and(non_finite) + masked_non_finite = (~loss_mask.bool()).logical_and(non_finite) + if active_non_finite.any().item() and masked_non_finite.any().item(): + logger.warning( + "Normalization input contains non-finite values at both " + "active and masked positions. Active non-finite values " + "will propagate through normalization; masked non-finite " + "values will be ignored by loss_mask. This may indicate " + "an upstream numerical issue." + ) + elif active_non_finite.any().item(): + logger.warning( + "Normalization input contains non-finite values at active " + "positions. They will propagate through normalization and " + "may indicate an upstream numerical issue." + ) + else: + logger.warning( + "Normalization input contains non-finite values at masked " + "positions. They will be ignored by loss_mask, but this " + "may indicate an upstream numerical issue." + ) + # Early return if no elements are active (all masked out) if loss_mask is not None and loss_mask.sum().item() == 0: return x.float() @@ -1446,10 +1478,10 @@ def __call__( mean = torch.zeros_like(x) # Subtract mean - x_centered = x - mean - # mask unrelevant elements as 0 if loss_mask is not None: - x_centered = x_centered * loss_mask + x_centered = torch.where(loss_mask.bool(), x - mean, 0.0) + else: + x_centered = x - mean # Step 2: Compute std if self.std_level == "batch": @@ -1517,7 +1549,7 @@ def _compute_mean( x_sum = x.sum(dim=dim, keepdim=True) else: mask = mask.to(dtype) - x_masked = x * mask + x_masked = torch.where(mask.bool(), x, 0.0) factor = mask.sum(dim, keepdim=True) x_sum = x_masked.sum(dim=dim, keepdim=True) @@ -1576,9 +1608,8 @@ def _compute_std( x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True) else: mask = mask.to(dtype) - x_masked = x * mask factor = mask.sum(dim, keepdim=True) - x_centered = x_masked - mean * mask # only apply mean where mask is 1 + x_centered = torch.where(mask.bool(), x - mean, 0.0) x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True) if dist.is_initialized() and all_reduce: diff --git a/tests/test_adv_norm_config.py b/tests/test_adv_norm_config.py index fb1553fee..76cf85da5 100644 --- a/tests/test_adv_norm_config.py +++ b/tests/test_adv_norm_config.py @@ -1104,6 +1104,99 @@ def test_non_trivial_loss_mask_batch_normalization(): assert torch.abs(non_masked_values.std() - 1.0) < 1e-5 +def test_masked_invalid_values_do_not_poison_batch_normalization(): + config = NormConfig(mean_level="batch", std_level="batch", group_size=1) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("nan")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32) + + normalized = adv_norm(advantages, loss_mask) + + expected_valid = torch.tensor( + [-1.0, 0.0, 1.0], + dtype=torch.float32, + ) + assert torch.isfinite(normalized).all() + assert torch.allclose(normalized[loss_mask.bool()], expected_valid, atol=1e-6) + assert normalized[0, 1].item() == 0.0 + + +def test_masked_invalid_values_emit_warning(): + config = NormConfig(mean_level="batch", std_level="batch", group_size=1) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("nan")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32) + + with patch("areal.utils.data.logger.warning") as warning: + normalized = adv_norm(advantages, loss_mask) + + warning.assert_called_once() + assert "non-finite values at masked positions" in warning.call_args.args[0] + assert torch.isfinite(normalized).all() + + +def test_masked_invalid_values_do_not_poison_group_normalization(): + config = NormConfig(mean_level="group", std_level="group", group_size=2) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("inf")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32) + + normalized = adv_norm(advantages, loss_mask) + + expected_valid = torch.tensor( + [-1.0, 0.0, 1.0], + dtype=torch.float32, + ) + assert torch.isfinite(normalized).all() + assert torch.allclose(normalized[loss_mask.bool()], expected_valid, atol=1e-6) + assert normalized[0, 1].item() == 0.0 + + +def test_unmasked_invalid_values_are_not_sanitized(): + config = NormConfig(mean_level="batch", std_level=None, group_size=1) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("nan")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.ones_like(advantages) + + normalized = adv_norm(advantages, loss_mask) + + assert torch.isnan(normalized).any() + + +def test_unmasked_invalid_values_emit_warning_and_are_not_sanitized(): + config = NormConfig(mean_level="batch", std_level=None, group_size=1) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("nan")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.ones_like(advantages) + + with patch("areal.utils.data.logger.warning") as warning: + normalized = adv_norm(advantages, loss_mask) + + warning.assert_called_once() + assert "non-finite values at active positions" in warning.call_args.args[0] + assert torch.isnan(normalized).any() + + def test_non_trivial_loss_mask_leave_one_out(): """Test leave-one-out normalization with non-trivial loss mask and verify expected values.""" config = NormConfig(