🐛 Bug
CEWeighted currently reduces the weighted loss with (loss * sample_weight).mean(), which is not a normalized weighted mean.
This divides by the number of elements N instead of by sum(sample_weight), so the loss scale tracks the average weight in each batch and can be unstable when the weight distribution is imbalanced.
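To make the scale dependence explicit, write l_i for the per-sample loss and w_i for its weight in a batch of N samples. The current reduction factors into the average batch weight times the normalized weighted mean:

```latex
\frac{1}{N}\sum_{i=1}^{N} w_i \ell_i
  \;=\;
  \underbrace{\frac{1}{N}\sum_{i=1}^{N} w_i}_{\bar{w}}
  \cdot
  \frac{\sum_{i=1}^{N} w_i \ell_i}{\sum_{i=1}^{N} w_i}
```

Because \bar{w} does not depend on the model parameters, both the loss value and its gradient are rescaled by \bar{w}, which acts like a learning rate that changes from batch to batch. With l = (1, 3) and w = (0.1, 10), \bar{w} = 5.05 and 5.05 * 30.1 / 10.1 = 15.05, whereas the normalized weighted mean is 30.1 / 10.1 ≈ 2.98.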
To Reproduce
Steps to reproduce the behavior:
- Open replay/nn/loss/ce.py and find CEWeighted.forward.
- Note the current aggregation:
loss = (loss * sample_weight).mean()
- Run a toy check:
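```python
import torch
# Per-sample losses and strongly imbalanced weights
l = torch.tensor([1.0, 3.0])
w = torch.tensor([0.1, 10.0])
current = (l * w).mean()  # divides by N = 2
expected = (l * w).sum() / w.sum()  # divides by sum(w) = 10.1
print(current.item(), expected.item())  # ~15.05 vs ~2.98
```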
Expected behavior
Use a normalized weighted mean, (loss * sample_weight).sum() / sample_weight.sum(),
with a guard against a zero denominator (e.g. sample_weight.sum().clamp_min(eps)).
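A minimal sketch of the proposed reduction, assuming loss and sample_weight already have the same shape. The helper name weighted_mean and the default eps are hypothetical and chosen for illustration only; they are not part of replay's API.

```python
import torch


def weighted_mean(loss: torch.Tensor, sample_weight: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical helper: normalized weighted mean sum(w * loss) / sum(w).

    Assumes `loss` and `sample_weight` have the same shape. The denominator is
    clamped to `eps`, so an all-zero weight batch yields 0 instead of NaN.
    """
    weighted_sum = (loss * sample_weight).sum()
    total_weight = sample_weight.sum().clamp_min(eps)
    return weighted_sum / total_weight


losses = torch.tensor([1.0, 3.0])
weights = torch.tensor([0.1, 10.0])
print(weighted_mean(losses, weights).item())  # ~2.98
# Rescaling all weights leaves the result unchanged, unlike (loss * w).mean()
print(weighted_mean(losses, weights * 4).item())  # ~2.98
# Zero total weight is guarded by clamp_min
print(weighted_mean(losses, torch.zeros(2)).item())  # 0.0
```

Clamping the denominator keeps the reduction differentiable and avoids a NaN loss for a batch whose weights are all zero. An alternative would be to skip such batches explicitly, which is a design decision for the maintainers.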
Additional context
The same reduction appears in:
- replay/nn/loss/ce.py (CEWeighted, CESampledWeighted)
- replay/nn/loss/logout_ce.py (LogOutCEWeighted)
Checklist