diff --git a/src/mrpro/operators/functionals/SSIM.py b/src/mrpro/operators/functionals/SSIM.py index 9947bde60..1300c6ccf 100644 --- a/src/mrpro/operators/functionals/SSIM.py +++ b/src/mrpro/operators/functionals/SSIM.py @@ -246,8 +246,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]: See this PyTorch `discussion `_. """ ssim = ssim3d( - self.target.real, - x.real, + self.target, + x, weight=self.weight, k1=self.k1, k2=self.k2, diff --git a/tests/operators/functionals/test_ssim.py b/tests/operators/functionals/test_ssim.py index e701c2197..4b156af27 100644 --- a/tests/operators/functionals/test_ssim.py +++ b/tests/operators/functionals/test_ssim.py @@ -46,3 +46,16 @@ def test_ssim_reduction() -> None: assert ssim_volume.shape == (2, 3) assert ssim_full.shape == () assert ssim_none.shape == (2, 3, 4, 4, 4) + + +def test_ssim_complex() -> None: + """Test the SSIM functional for complex-valued tensors.""" + rng = RandomGenerator(0) + target_real = rng.float32_tensor((1, 10, 10), low=0.0, high=1.0) + target_imag = rng.float32_tensor((1, 10, 10), low=0.0, high=1.0) + target = target_real + 1j * target_imag + test = rng.complex64_tensor((1, 10, 10), low=0.0, high=1.0) + + torch.testing.assert_close( + SSIM(target)(test)[0], 0.5 * (SSIM(target.real)(test.real)[0] + (SSIM(target.imag)(test.imag)[0])) + )