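In the current implementation, the emissions and the predictions each subtract their own maximum (over the vocabulary dimension V) before being exponentiated. But consider this case, where the two vectors peak at different vocabulary entries: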
```
emission[0, 0]   = [0, -1000]
prediction[0, 0] = [-1000, 0]
->
# current impl (here maxEs = maxPs = 0)
logNorm[0, 0, 0] = log(exp(emission[0, 0] - maxEs) @ exp(prediction[0, 0] - maxPs)) + maxEs + maxPs
                 = log(exp([0, -1000]) @ exp([-1000, 0]))
                 = log([1, exp(-1000)] @ [exp(-1000), 1])   <-- exp(-1000) underflows to 0 in FP32
                 = log(0)
                 = -inf
# correct result
logNorm[0, 0, 0] = log(exp(-1000) + exp(-1000)) = log(2) - 1000
```
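The worked example above can be replayed step by step in PyTorch (FP32). This is a minimal sketch that uses single length-2 vectors for `emission[0, 0]` and `prediction[0, 0]` instead of the full batched tensors:

```python
import math
import torch

emission = torch.tensor([0.0, -1000.0])    # emission[0, 0], V = 2
prediction = torch.tensor([-1000.0, 0.0])  # prediction[0, 0]
maxEs, maxPs = emission.max(), prediction.max()  # both 0 here

# current impl: shift each vector by its own max, then dot the exponentials
a = torch.exp(emission - maxEs)    # [1, exp(-1000)] -> [1, 0] in FP32
b = torch.exp(prediction - maxPs)  # [exp(-1000), 1] -> [0, 1] in FP32
current = torch.log(a @ b) + maxEs + maxPs
print(current.item())  # -inf

# correct result: log-sum-exp of the joint logits emission + prediction,
# which shifts by the max of the *sum*, so nothing underflows
correct = torch.logsumexp(emission + prediction, dim=0)
print(correct.item(), math.log(2) - 1000)
```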
I also tried converting the emissions and predictions to FP64 before computing logNorm, but it still didn't work in my ASR experiment.
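A possible explanation for why FP64 was not enough: casting only lowers the point where `exp` underflows to 0, from roughly -103 in FP32 to roughly -745 in FP64, so in a case like the example above, a larger gap between where the emission and the prediction peak still yields log(0). The sketch below compares gaps of 500 and 1000; `log_norm_pair` is a hypothetical per-pair helper that mirrors the current per-vector max shift:

```python
import math
import torch

# Smallest positive (subnormal) values: 2**-149 in FP32, 2**-1074 in FP64.
# exp(x) rounds to 0 once x drops below roughly these logs.
print(math.log(2.0 ** -149), math.log(2.0 ** -1074))  # about -103.3 and -744.4

def log_norm_pair(e, p):
    # Current approach for one (t, u) pair: per-vector max shift, then a dot product.
    me, mp = e.max(), p.max()
    return torch.log(torch.dot(torch.exp(e - me), torch.exp(p - mp))) + me + mp

results = {}
for gap in (500.0, 1000.0):
    for dtype in (torch.float32, torch.float64):
        e = torch.tensor([0.0, -gap], dtype=dtype)
        p = torch.tensor([-gap, 0.0], dtype=dtype)
        results[(gap, dtype)] = log_norm_pair(e, p).item()
        print(f"gap={gap:.0f} {dtype}: {results[(gap, dtype)]} (exact {math.log(2) - gap:.4f})")
```

With a gap of 500 only FP32 fails; with a gap of 1000, as in the example, both dtypes return -inf.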
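The broadcast-sum approach is more numerically stable, because the max shift inside `logsumexp` is applied to the joint logits emission[t] + prediction[u] themselves, but it materializes a (B, T, U, V) tensor and therefore needs O(B*T*U*V) memory:
```python
# emission: (B, T, V), prediction: (B, U, V) -> logNorm: (B, T, U)
logNorm = torch.logsumexp(emission.unsqueeze(2) + prediction.unsqueeze(1), dim=-1)
```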
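A rough element count of the main intermediates makes the trade-off concrete. This is a sketch under stated assumptions: `intermediate_numel` is a hypothetical helper, it ignores temporaries such as `emissions - maxEs`, and the sizes in the last lines are illustrative, not taken from the experiment:

```python
import torch

def intermediate_numel(B, T, U, V):
    # Hypothetical helper: rough count of the dominant intermediate elements.
    # bmm path: exp(emissions) (B,T,V) + exp(predictions) (B,U,V) + result (B,T,U)
    bmm_path = B * T * V + B * U * V + B * T * U
    # broadcast path: the joint logits (B,T,U,V) dominate
    broadcast_path = B * T * U * V
    return bmm_path, broadcast_path

B, T, U, V = 2, 6, 4, 5
emission = torch.randn(B, T, V)
prediction = torch.randn(B, U, V)
joint = emission.unsqueeze(2) + prediction.unsqueeze(1)  # (B, T, U, V)
logNorm = torch.logsumexp(joint, dim=-1)                 # (B, T, U)
print(joint.shape, logNorm.shape)

# Illustrative sizes only, FP32 (4 bytes per element)
bmm_n, bc_n = intermediate_numel(8, 500, 100, 5000)
print(bmm_n * 4 / 2**20, "MiB vs", bc_n * 4 / 2**30, "GiB")  # about 93 MiB vs about 7.5 GiB
```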
```python
maxEs = emissions.max(dim=2, keepdim=True)[0]
maxPs = predictions.max(dim=2, keepdim=True)[0]
log_norms = torch.log(torch.bmm(
    torch.exp(emissions - maxEs),
    torch.exp((predictions - maxPs)).transpose(1, 2)))
log_norms = log_norms + maxEs + maxPs.transpose(1, 2)
```
Source of the snippet above: `transducer/transducer/torch_binding.py`, lines 162 to 167 in e90c6f4.
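The snippet relies on the identity sum_v exp(e_v + p_v) = exp(e) · exp(p), which turns the sum over V for every (t, u) pair into one batched matmul. As a sketch, the lines can be wrapped in a standalone function (`current_log_norms` is a hypothetical name, and the shapes emissions (B, T, V) and predictions (B, U, V) are assumed from the `max(dim=2)` and `transpose(1, 2)` calls) and compared against `torch.logsumexp` on well-scaled random inputs. They agree there, which suggests the failure is purely numerical rather than a flaw in the factorization:

```python
import torch

def current_log_norms(emissions, predictions):
    # Lines 162-167 wrapped in a function so they can run standalone.
    # emissions: (B, T, V), predictions: (B, U, V) -> log_norms: (B, T, U)
    maxEs = emissions.max(dim=2, keepdim=True)[0]    # (B, T, 1)
    maxPs = predictions.max(dim=2, keepdim=True)[0]  # (B, U, 1)
    log_norms = torch.log(torch.bmm(
        torch.exp(emissions - maxEs),                       # (B, T, V)
        torch.exp((predictions - maxPs)).transpose(1, 2)))  # (B, V, U)
    return log_norms + maxEs + maxPs.transpose(1, 2)        # (B, T, U)

torch.manual_seed(0)
emissions = torch.randn(2, 6, 5)
predictions = torch.randn(2, 4, 5)
reference = torch.logsumexp(emissions.unsqueeze(2) + predictions.unsqueeze(1), dim=-1)
result = current_log_norms(emissions, predictions)
print(torch.allclose(result, reference, atol=1e-5))  # True for well-scaled inputs like these
```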