diff --git a/test/metrics_st.py b/test/metrics_st.py index eda2acc..ddfda73 100644 --- a/test/metrics_st.py +++ b/test/metrics_st.py @@ -26,10 +26,10 @@ def compute_depth_metrics(gt, pred, mask=None, median_align=False): a3 = (thresh < 1.25 ** 3).float().mean() rmse = (gt - pred) ** 2 - rmse = torch.sqrt(rmse).mean() + rmse = torch.sqrt(rmse.mean()) rmse_log = (torch.log10(gt) - torch.log10(pred)) ** 2 - rmse_log = torch.sqrt(rmse_log).mean() + rmse_log = torch.sqrt(rmse_log.mean()) abs_ = torch.mean(torch.abs(gt - pred))