I want to investigate whether we should compute metrics only when calling `loss.backward()`, using the [TorchMetrics Lightning integration](https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html).