Problem
The interval Tobit loss computes the log-probability of a value falling within [lower, upper] as:
log_prob = torch.log(
    torch.clamp(torch.exp(log_p_upper) - torch.exp(log_p_lower), min=1e-12)
)
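When both log_p_upper and log_p_lower are very negative (i.e. the model predicts far outside the interval), the difference of the two exponentials
drops below the 1e-12 floor, and in float32 exp(·) eventually underflows to 0. The result is pinned at log(1e-12) ≈ -27.6, an artificial loss value
unrelated to the true log-probability. While the clamp is active it also passes zero gradient, so the loss gives no signal in the region where
predictions are worst, and the abrupt switch at the floor can produce erratic gradients. This can destabilize training for out-of-range predictions.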
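The floor can be reproduced with plain arithmetic. The sketch below uses illustrative log-CDF values (-40 and -41, picked for this example rather than taken from the model) and ordinary Python floats, where the unclamped logarithm is still representable and can serve as a reference:

```python
import math

# Illustrative log-CDF values for a prediction far outside the interval.
log_p_upper = -40.0
log_p_lower = -41.0

# Current formulation: the difference is far below the 1e-12 floor,
# so the clamp decides the result.
diff = math.exp(log_p_upper) - math.exp(log_p_lower)
naive = math.log(max(diff, 1e-12))

# Reference: in float64 the difference is still representable,
# so its logarithm can be taken directly.
reference = math.log(diff)

print(f"diff      = {diff:.3e}")
print(f"naive     = {naive:.3f}")      # -27.631
print(f"reference = {reference:.3f}")  # -40.459
```

The clamped value is off by almost 13 nats in this case, and it stays at -27.631 no matter how much further the prediction drifts.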
Proposed Fix
Use numerically stable log-space subtraction:
# log(exp(a) - exp(b)) = a + log(1 - exp(b - a)), for a > b
log_diff = log_p_upper + torch.log1p(
    -torch.exp((log_p_lower - log_p_upper).clamp(max=-1e-7))
)
log_prob = log_diff.clamp(min=-100.0)
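torch.log1p is numerically stable for arguments near 0, which -exp(b - a) approaches when log_p_lower is far below log_p_upper. When the two values
nearly coincide (a very narrow interval), exp(b - a) approaches 1 instead; the clamp(max=-1e-7) is there to keep the log1p argument above -1 in
float32 so the result stays finite. Clamping log_prob at −100 bounds the loss for extreme predictions while keeping it continuous. The gradient is
still zero below that floor, but the floor is reached far later than the old -27.6 one.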
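As a minimal, self-contained sketch of the proposed computation: the function name stable_interval_log_prob, the floor parameter, and the example tensors are hypothetical and not taken from moal/model.py. The sketch assumes log_p_upper >= log_p_lower elementwise.

```python
import torch


def stable_interval_log_prob(log_p_upper, log_p_lower, floor=-100.0):
    """Compute log(exp(log_p_upper) - exp(log_p_lower)) in log space.

    Hypothetical helper; assumes log_p_upper >= log_p_lower elementwise.
    """
    # Keep the gap strictly negative so that 1 - exp(gap) stays positive.
    gap = (log_p_lower - log_p_upper).clamp(max=-1e-7)
    log_diff = log_p_upper + torch.log1p(-torch.exp(gap))
    # Bound the loss contribution for extreme predictions.
    return log_diff.clamp(min=floor)


# First entry: far outside the interval. Second entry: an ordinary case.
log_p_upper = torch.tensor([-40.0, -0.1], requires_grad=True)
log_p_lower = torch.tensor([-41.0, -3.0])

log_prob = stable_interval_log_prob(log_p_upper, log_p_lower)
log_prob.sum().backward()

print(log_prob)          # first entry is about -40.46, not the -27.63 floor
print(log_p_upper.grad)  # nonzero for both entries
```

Because neither clamp is active for these inputs, the gradient with respect to log_p_upper is 1 / (1 - exp(b - a)), which stays finite and informative even for the out-of-range entry.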
Files
- moal/model.py (_tobit_loss or equivalent)
Notes
This is a correctness fix, not a feature, so it should be included in the next patch release. Add a unit test with a case where both log_p_upper <
-10 and log_p_lower < -10 to catch regressions, including values below about -28, where the old 1e-12 floor was active.
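One possible shape for that regression test is sketched below. interval_log_prob is a local stand-in written for this example. In the repository the test would call the real loss code instead, and the input values are illustrative. The float64 computation serves as the reference, since the direct formula is still accurate at that precision for these inputs.

```python
import torch


def interval_log_prob(log_p_upper, log_p_lower):
    # Local stand-in for the patched computation (hypothetical name).
    gap = (log_p_lower - log_p_upper).clamp(max=-1e-7)
    return (log_p_upper + torch.log1p(-torch.exp(gap))).clamp(min=-100.0)


def test_interval_log_prob_far_outside_interval():
    log_p_upper = torch.tensor([-12.0, -40.0, -80.0], requires_grad=True)
    log_p_lower = torch.tensor([-13.0, -42.0, -85.0])

    log_prob = interval_log_prob(log_p_upper, log_p_lower)

    # float64 reference: exp() does not underflow for these inputs.
    a = log_p_upper.detach().double()
    b = log_p_lower.double()
    expected = torch.log(torch.exp(a) - torch.exp(b))
    assert torch.allclose(log_prob.detach().double(), expected, atol=1e-4)

    # The clamp-based version pinned the last two entries at
    # log(1e-12) with zero gradient; the fix must keep gradients alive.
    log_prob.sum().backward()
    assert torch.all(log_p_upper.grad > 0)


test_interval_log_prob_far_outside_interval()
```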