From 81d17d16928254677e03e134fbd3a797fd1916f3 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Sun, 17 May 2026 23:55:29 +0800 Subject: [PATCH 1/7] fix(utils): ignore masked invalid normalization values --- areal/utils/data.py | 11 +++++---- tests/test_adv_norm_config.py | 42 +++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/areal/utils/data.py b/areal/utils/data.py index 09368e4ef2..eafc22fad4 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -1449,7 +1449,7 @@ def __call__( x_centered = x - mean # mask unrelevant elements as 0 if loss_mask is not None: - x_centered = x_centered * loss_mask + x_centered = torch.where(loss_mask.bool(), x_centered, 0.0) # Step 2: Compute std if self.std_level == "batch": @@ -1517,7 +1517,7 @@ def _compute_mean( x_sum = x.sum(dim=dim, keepdim=True) else: mask = mask.to(dtype) - x_masked = x * mask + x_masked = torch.where(mask.bool(), x * mask, torch.zeros_like(x)) factor = mask.sum(dim, keepdim=True) x_sum = x_masked.sum(dim=dim, keepdim=True) @@ -1576,9 +1576,12 @@ def _compute_std( x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True) else: mask = mask.to(dtype) - x_masked = x * mask factor = mask.sum(dim, keepdim=True) - x_centered = x_masked - mean * mask # only apply mean where mask is 1 + x_centered = torch.where( + mask.bool(), + (x - mean) * mask, + torch.zeros_like(x), + ) x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True) if dist.is_initialized() and all_reduce: diff --git a/tests/test_adv_norm_config.py b/tests/test_adv_norm_config.py index fb1553fee4..b55a7b74a4 100644 --- a/tests/test_adv_norm_config.py +++ b/tests/test_adv_norm_config.py @@ -1104,6 +1104,48 @@ def test_non_trivial_loss_mask_batch_normalization(): assert torch.abs(non_masked_values.std() - 1.0) < 1e-5 +def test_masked_invalid_values_do_not_poison_batch_normalization(): + config = NormConfig(mean_level="batch", std_level="batch", group_size=1) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("nan")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32) + + normalized = adv_norm(advantages, loss_mask) + + expected_valid = torch.tensor( + [-1.2247374, 0.0, 1.2247374], + dtype=torch.float32, + ) + assert torch.isfinite(normalized).all() + assert torch.allclose(normalized[loss_mask.bool()], expected_valid, atol=1e-6) + assert normalized[0, 1].item() == 0.0 + + +def test_masked_invalid_values_do_not_poison_group_normalization(): + config = NormConfig(mean_level="group", std_level="group", group_size=2) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("inf")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32) + + normalized = adv_norm(advantages, loss_mask) + + expected_valid = torch.tensor( + [-1.2247374, 0.0, 1.2247374], + dtype=torch.float32, + ) + assert torch.isfinite(normalized).all() + assert torch.allclose(normalized[loss_mask.bool()], expected_valid, atol=1e-6) + assert normalized[0, 1].item() == 0.0 + + def test_non_trivial_loss_mask_leave_one_out(): """Test leave-one-out normalization with non-trivial loss mask and verify expected values.""" config = NormConfig( From e2dae9ddd98a0e312c2a39624062ec4412e815a0 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Tue, 19 May 2026 00:37:37 +0800 Subject: [PATCH 2/7] fix(utils): avoid masked normalization NaN intermediates --- areal/utils/data.py | 17 ++++++++--------- tests/test_adv_norm_config.py | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/areal/utils/data.py b/areal/utils/data.py index eafc22fad4..488f3e71d5 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -1446,10 +1446,11 @@ def __call__( mean = torch.zeros_like(x) # Subtract mean - x_centered = x - mean - # mask unrelevant elements as 0 if loss_mask is not None: - x_centered = torch.where(loss_mask.bool(), x_centered, 0.0) + x_safe = torch.where(loss_mask.bool(), x, 0.0) + x_centered = (x_safe - mean) * loss_mask + else: + x_centered = x - mean # Step 2: Compute std if self.std_level == "batch": @@ -1517,7 +1518,8 @@ def _compute_mean( x_sum = x.sum(dim=dim, keepdim=True) else: mask = mask.to(dtype) - x_masked = torch.where(mask.bool(), x * mask, torch.zeros_like(x)) + x_safe = torch.where(mask.bool(), x, 0.0) + x_masked = x_safe * mask factor = mask.sum(dim, keepdim=True) x_sum = x_masked.sum(dim=dim, keepdim=True) @@ -1577,11 +1579,8 @@ def _compute_std( else: mask = mask.to(dtype) factor = mask.sum(dim, keepdim=True) - x_centered = torch.where( - mask.bool(), - (x - mean) * mask, - torch.zeros_like(x), - ) + x_safe = torch.where(mask.bool(), x, 0.0) + x_centered = (x_safe - mean) * mask x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True) if dist.is_initialized() and all_reduce: diff --git a/tests/test_adv_norm_config.py b/tests/test_adv_norm_config.py index b55a7b74a4..b89c8e3712 100644 --- a/tests/test_adv_norm_config.py +++ b/tests/test_adv_norm_config.py @@ -1117,7 +1117,7 @@ def test_masked_invalid_values_do_not_poison_batch_normalization(): normalized = adv_norm(advantages, loss_mask) expected_valid = torch.tensor( - [-1.2247374, 0.0, 1.2247374], + [-1.0, 0.0, 1.0], dtype=torch.float32, ) assert torch.isfinite(normalized).all() @@ -1138,7 +1138,7 @@ def test_masked_invalid_values_do_not_poison_group_normalization(): normalized = adv_norm(advantages, loss_mask) expected_valid = torch.tensor( - [-1.2247374, 0.0, 1.2247374], + [-1.0, 0.0, 1.0], dtype=torch.float32, ) assert torch.isfinite(normalized).all() From 8e82c625fa194fffff155e9b6bc8ff9405290be2 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Tue, 19 May 2026 00:46:06 +0800 Subject: [PATCH 3/7] fix(utils): simplify masked normalization mean --- areal/utils/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/utils/data.py b/areal/utils/data.py index 488f3e71d5..cf38633db9 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -1519,7 +1519,7 @@ def _compute_mean( else: mask = mask.to(dtype) x_safe = torch.where(mask.bool(), x, 0.0) - x_masked = x_safe * mask + x_masked = x_safe factor = mask.sum(dim, keepdim=True) x_sum = x_masked.sum(dim=dim, keepdim=True) From 18ead5dad36746855f29cb14eafca543efd6b865 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Tue, 19 May 2026 00:49:33 +0800 Subject: [PATCH 4/7] fix(utils): simplify masked normalization std --- areal/utils/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/areal/utils/data.py b/areal/utils/data.py index cf38633db9..a7dfe03ece 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -1579,8 +1579,7 @@ def _compute_std( else: mask = mask.to(dtype) factor = mask.sum(dim, keepdim=True) - x_safe = torch.where(mask.bool(), x, 0.0) - x_centered = (x_safe - mean) * mask + x_centered = torch.where(mask.bool(), x - mean, 0.0) x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True) if dist.is_initialized() and all_reduce: From 89eb043b9651466ea74dc60dea70b723a1512710 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Tue, 19 May 2026 00:59:56 +0800 Subject: [PATCH 5/7] fix(utils): inline masked mean sanitization --- areal/utils/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/areal/utils/data.py b/areal/utils/data.py index a7dfe03ece..eef8117844 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -1518,8 +1518,7 @@ def _compute_mean( x_sum = x.sum(dim=dim, keepdim=True) else: mask = mask.to(dtype) - x_safe = torch.where(mask.bool(), x, 0.0) - x_masked = x_safe + x_masked = torch.where(mask.bool(), x, 0.0) factor = mask.sum(dim, keepdim=True) x_sum = x_masked.sum(dim=dim, keepdim=True) From a3890073568cbe4a746bc7e738b376a20336db37 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Tue, 19 May 2026 11:56:47 +0800 Subject: [PATCH 6/7] fix(utils): simplify masked norm centering --- areal/utils/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/areal/utils/data.py b/areal/utils/data.py index eef8117844..23b4f53449 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -1447,8 +1447,7 @@ def __call__( # Subtract mean if loss_mask is not None: - x_safe = torch.where(loss_mask.bool(), x, 0.0) - x_centered = (x_safe - mean) * loss_mask + x_centered = torch.where(loss_mask.bool(), x - mean, 0.0) else: x_centered = x - mean From e962f36cdfab766fd29c08c335326fa1b05a2129 Mon Sep 17 00:00:00 2001 From: haoyang9804 Date: Wed, 20 May 2026 12:02:52 +0800 Subject: [PATCH 7/7] fix(utils): warn on masked non-finite norm inputs --- areal/utils/data.py | 32 ++++++++++++++++++++++ tests/test_adv_norm_config.py | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/areal/utils/data.py b/areal/utils/data.py index 23b4f53449..bf25231de8 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -1404,6 +1404,38 @@ def __call__( bs = x.size(0) eps = self.eps + non_finite = ~torch.isfinite(x) + if non_finite.any().item(): + if loss_mask is None: + logger.warning( + "Normalization input contains non-finite values and no " + "loss_mask was provided. They will propagate through " + "normalization and may indicate an upstream numerical issue." + ) + else: + active_non_finite = loss_mask.bool().logical_and(non_finite) + masked_non_finite = (~loss_mask.bool()).logical_and(non_finite) + if active_non_finite.any().item() and masked_non_finite.any().item(): + logger.warning( + "Normalization input contains non-finite values at both " + "active and masked positions. Active non-finite values " + "will propagate through normalization; masked non-finite " + "values will be ignored by loss_mask. This may indicate " + "an upstream numerical issue." + ) + elif active_non_finite.any().item(): + logger.warning( + "Normalization input contains non-finite values at active " + "positions. They will propagate through normalization and " + "may indicate an upstream numerical issue." + ) + else: + logger.warning( + "Normalization input contains non-finite values at masked " + "positions. They will be ignored by loss_mask, but this " + "may indicate an upstream numerical issue." + ) + # Early return if no elements are active (all masked out) if loss_mask is not None and loss_mask.sum().item() == 0: return x.float() diff --git a/tests/test_adv_norm_config.py b/tests/test_adv_norm_config.py index b89c8e3712..76cf85da5a 100644 --- a/tests/test_adv_norm_config.py +++ b/tests/test_adv_norm_config.py @@ -1125,6 +1125,24 @@ def test_masked_invalid_values_do_not_poison_batch_normalization(): assert normalized[0, 1].item() == 0.0 +def test_masked_invalid_values_emit_warning(): + config = NormConfig(mean_level="batch", std_level="batch", group_size=1) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("nan")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32) + + with patch("areal.utils.data.logger.warning") as warning: + normalized = adv_norm(advantages, loss_mask) + + warning.assert_called_once() + assert "non-finite values at masked positions" in warning.call_args.args[0] + assert torch.isfinite(normalized).all() + + def test_masked_invalid_values_do_not_poison_group_normalization(): config = NormConfig(mean_level="group", std_level="group", group_size=2) adv_norm = Normalization(config) @@ -1146,6 +1164,39 @@ def test_masked_invalid_values_do_not_poison_group_normalization(): assert normalized[0, 1].item() == 0.0 +def test_unmasked_invalid_values_are_not_sanitized(): + config = NormConfig(mean_level="batch", std_level=None, group_size=1) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("nan")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.ones_like(advantages) + + normalized = adv_norm(advantages, loss_mask) + + assert torch.isnan(normalized).any() + + +def test_unmasked_invalid_values_emit_warning_and_are_not_sanitized(): + config = NormConfig(mean_level="batch", std_level=None, group_size=1) + adv_norm = Normalization(config) + + advantages = torch.tensor( + [[1.0, float("nan")], [3.0, 5.0]], + dtype=torch.float32, + ) + loss_mask = torch.ones_like(advantages) + + with patch("areal.utils.data.logger.warning") as warning: + normalized = adv_norm(advantages, loss_mask) + + warning.assert_called_once() + assert "non-finite values at active positions" in warning.call_args.args[0] + assert torch.isnan(normalized).any() + + def test_non_trivial_loss_mask_leave_one_out(): """Test leave-one-out normalization with non-trivial loss mask and verify expected values.""" config = NormConfig(