
Possible numerical error in log-norm computation #13

@maxwellzh

In the current implementation, the emissions and the predictions are each shifted by their own maximum (over the vocabulary dimension) before being exponentiated. This can underflow. Consider this case:

```
emission[0, 0]   = [0, -1000]
prediction[0, 0] = [-1000, 0]

# current impl
logNorm[0, 0, 0] = log(exp(emission[0, 0] - maxEs) @ exp(prediction[0, 0] - maxPs)) + maxEs + maxPs
                 = log(exp([0, -1000]) @ exp([-1000, 0]))       # maxEs = maxPs = 0 here
                 = log([1, exp(-1000)] @ [exp(-1000), 1])       # exp(-1000) underflows to 0 in FP32
                 = log(0)
                 = -inf

# correct result
logNorm[0, 0, 0] = log(exp(0 - 1000) + exp(-1000 + 0)) = log(2) - 1000
```
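A minimal pure-Python sketch of the example above, for a single (t, u) pair. The helper names `log_norm_separate_max` and `log_norm_joint_max` are hypothetical, not from the repository: the first mimics the per-operand max shift, the second shifts by the max of the summed logits (a plain logsumexp). Python floats are FP64, so this also shows that a gap of 1000 underflows even in double precision.

```python
import math


def log_norm_separate_max(e, p):
    """Mimics the current scheme: shift each operand by its own max."""
    max_e, max_p = max(e), max(p)
    # Dot product of exp(e - max_e) and exp(p - max_p)
    s = sum(math.exp(a - max_e) * math.exp(b - max_p) for a, b in zip(e, p))
    # math.log(0.0) raises, so map an underflowed sum to -inf like torch.log does
    return (math.log(s) if s > 0.0 else -math.inf) + max_e + max_p


def log_norm_joint_max(e, p):
    """Stable version: shift by the max of the summed logits (logsumexp)."""
    z = [a + b for a, b in zip(e, p)]
    m = max(z)
    return m + math.log(sum(math.exp(x - m) for x in z))


e = [0.0, -1000.0]
p = [-1000.0, 0.0]
print(log_norm_separate_max(e, p))  # -inf
print(log_norm_joint_max(e, p))     # log(2) - 1000, about -999.3069
```

On well-behaved inputs (e.g. `e = [1.0, 2.0]`, `p = [0.5, -0.5]`) the two functions agree; they only diverge when the emission and prediction maxima sit on different vocabulary entries far apart in log space.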

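I also tried converting the emissions and predictions to FP64 before computing the log-norm, but it still didn't work in my ASR experiment. For a gap this large that is expected: exp(x) underflows to 0 below roughly x = -104 in FP32 and x = -745 in FP64, so exp(-1000) is 0 in both.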
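A small stdlib-only sketch that checks where exp underflows in each precision. Rounding through `struct` with format `"f"` is an assumption of this sketch, used here as a stand-in for a real FP32 tensor; the gaps tested are arbitrary points on either side of the two thresholds, plus the -1000 gap from the example.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (an IEEE FP64 double) to the nearest FP32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Log-space gaps around the FP32 (about -104) and FP64 (about -745)
# underflow points, plus the -1000 gap from the example.
for gap in (-100.0, -110.0, -700.0, -750.0, -1000.0):
    fp64 = math.exp(gap)   # Python floats are FP64
    fp32 = to_fp32(fp64)   # same value rounded to FP32
    print(f"exp({gap:g}): fp32={fp32!r}, fp64={fp64!r}")
```

Both columns are 0.0 for the -1000 gap, which is consistent with FP64 not helping here.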

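The broadcast-sum approach is more numerically stable, but it materializes a (B, T, U, V) tensor and therefore needs O(B*T*U*V) memory:

```python
# (B, T, 1, V) + (B, 1, U, V) -> (B, T, U, V); logsumexp over V -> (B, T, U) log-norm
logNorm = torch.logsumexp(emission.unsqueeze(2) + prediction.unsqueeze(1), dim=-1)
```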
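To put the O(B*T*U*V) cost in perspective, a back-of-the-envelope sketch. The batch, frame, label and vocabulary sizes below are hypothetical, picked only for illustration, and the helper `tensor_bytes` is not from the repository.

```python
def tensor_bytes(*shape: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed for a dense tensor of the given shape (FP32 by default)."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_elem


# Hypothetical sizes: batch B, frames T, labels U, vocab V.
B, T, U, V = 16, 500, 100, 5000
joint = tensor_bytes(B, T, U, V)  # broadcast sum: (B, T, U, V)
norm = tensor_bytes(B, T, U)      # log-norm only: (B, T, U)
print(f"(B, T, U, V): {joint / 2**30:.1f} GiB")  # 14.9 GiB
print(f"(B, T, U):    {norm / 2**20:.1f} MiB")   # 3.1 MiB
```

The bmm route never builds the 4D tensor, which is why it is attractive despite the underflow problem; the ratio between the two is exactly V.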

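For reference, the current implementation:

```python
import torch

# emissions: (B, T, V), predictions: (B, U, V)
maxEs = emissions.max(dim=2, keepdim=True)[0]    # (B, T, 1)
maxPs = predictions.max(dim=2, keepdim=True)[0]  # (B, U, 1)
# exp(.) @ exp(.)^T: (B, T, V) @ (B, V, U) -> (B, T, U)
log_norms = torch.log(torch.bmm(
    torch.exp(emissions - maxEs),
    torch.exp(predictions - maxPs).transpose(1, 2)))
log_norms = log_norms + maxEs + maxPs.transpose(1, 2)
```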
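Why the separate shifts can underflow, written out for one (t, u) pair. Here $e_v$ and $p_v$ stand for `emissions[b, t, v]` and `predictions[b, u, v]`; the symbols $M_e$, $M_p$, $S$ and $g$ are introduced only for this derivation.

$$
\begin{aligned}
\text{logNorm}_{t,u} &= \log \sum_{v} \exp(e_v + p_v) \\
&= M_e + M_p + \log \underbrace{\sum_{v} \exp(e_v - M_e)\,\exp(p_v - M_p)}_{S},
\qquad M_e = \max_v e_v,\quad M_p = \max_v p_v .
\end{aligned}
$$

Every term of $S$ equals $\exp(e_v + p_v - M_e - M_p)$, so with $g = \max_v (e_v + p_v) - M_e - M_p \le 0$:

$$
e^{g} \;\le\; S \;\le\; V\,e^{g}.
$$

$S$ is therefore computed as 0 once $e^{g}$ drops below the smallest representable value (roughly $g < -104$ in FP32, $g < -745$ in FP64). This happens when the emissions and the predictions peak on different vocabulary entries; in the example, $g = -1000 - 0 - 0 = -1000$. Shifting by the joint maximum $\max_v (e_v + p_v)$ would avoid it, but that maximum depends on both $t$ and $u$ and in general is not of the form $a_t + b_u$, so it cannot be folded into the two separate `exp()` calls that feed `bmm`.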
