CEWeighted loss issue #91

@Escape142

🐛 Bug

CEWeighted currently aggregates the weighted loss as (loss * sample_weight).mean(), which is not a normalized weighted mean.
It divides by the batch size N instead of sum(sample_weight), so the loss scale depends on the average weight in the batch and can swing from batch to batch when the weight distribution is imbalanced.
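
To make the scale dependence explicit, the two aggregations can be related as follows. This is only a restatement of the formulas in this issue; the symbols (per-sample loss \ell_i, weight w_i, batch size N, mean batch weight \bar{w}) are notation introduced here, not names from the code base.

```latex
\frac{1}{N}\sum_{i=1}^{N} w_i \ell_i
  \;=\; \bar{w}\,\cdot\,\frac{\sum_{i=1}^{N} w_i \ell_i}{\sum_{i=1}^{N} w_i},
\qquad
\bar{w} \;=\; \frac{1}{N}\sum_{i=1}^{N} w_i
```

So the current value equals the normalized weighted mean multiplied by the mean batch weight, and any change in \bar{w} between batches rescales the loss.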

To Reproduce

Steps to reproduce the behavior:

  1. Open replay/nn/loss/ce.py and find CEWeighted.forward.
  2. See current aggregation:
    loss = (loss * sample_weight).mean()
  3. Run a toy check:
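    import torch
    l = torch.tensor([1.0, 3.0])    # per-sample losses
    w = torch.tensor([0.1, 10.0])   # per-sample weights
    current = (l * w).mean()             # divides by N: about 15.05
    expected = (l * w).sum() / w.sum()   # divides by sum(w): about 2.98
    print(current.item(), expected.item())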

Expected behavior

Use weighted mean: (loss * sample_weight).sum() / sample_weight.sum()
(with denominator safety, e.g. clamp_min(eps)).
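
A minimal sketch of what the proposed aggregation could look like, assuming non-negative weights. The helper name `weighted_mean_loss` and the default `eps=1e-8` are hypothetical choices for illustration and are not taken from the repository.

```python
import torch


def weighted_mean_loss(loss: torch.Tensor, sample_weight: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: normalized weighted mean of per-sample losses.

    Assumes sample_weight is non-negative and has the same shape as loss.
    """
    numerator = (loss * sample_weight).sum()
    # clamp_min keeps the denominator away from zero when all weights are zero
    denominator = sample_weight.sum().clamp_min(eps)
    return numerator / denominator


loss = torch.tensor([1.0, 3.0])
weight = torch.tensor([0.1, 10.0])

normalized = weighted_mean_loss(loss, weight)
# Rescaling every weight by a constant should leave the result unchanged
rescaled = weighted_mean_loss(loss, 100.0 * weight)
# An all-zero weight vector yields 0 instead of NaN
all_zero = weighted_mean_loss(loss, torch.zeros(2))

print(normalized.item(), rescaled.item(), all_zero.item())
```

With this form the result no longer changes when all weights are multiplied by a constant, which is the property the current `.mean()` aggregation lacks.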

Additional context

Same pattern appears in:

  • replay/nn/loss/ce.py (CEWeighted, CESampledWeighted)
  • replay/nn/loss/logout_ce.py (LogOutCEWeighted)

Checklist

  • bug description
  • steps to reproduce
  • expected behavior
  • code sample / screenshots
