diff --git a/src/mrpro/operators/NonUniformFastFourierOp.py b/src/mrpro/operators/NonUniformFastFourierOp.py index 47ce69e32..9faccbaa8 100644 --- a/src/mrpro/operators/NonUniformFastFourierOp.py +++ b/src/mrpro/operators/NonUniformFastFourierOp.py @@ -6,15 +6,18 @@ from types import EllipsisType from typing import Literal +import einops import numpy as np import torch from pytorch_finufft.functional import finufft_type1, finufft_type2 from typing_extensions import Self +from mrpro.data.DcfData import DcfData from mrpro.data.KTrajectory import KTrajectory from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.operators.FastFourierOp import FastFourierOp from mrpro.operators.LinearOperator import LinearOperator +from mrpro.operators.PCACompressionOp import PCACompressionOp +from mrpro.utils import normalize_index, unsqueeze_right class NonUniformFastFourierOp(LinearOperator, adjoint_as_backward=True): @@ -200,12 +203,14 @@ def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]: Parameters ---------- x - Coil image data, typically with shape `(..., coils, z, y, x)`. + Image-space data with selected spatial dimensions as trailing axes, e.g. + `(..., coils, z, y, x)` for 3D or `(..., coils, y, x)` for 2D. Returns ------- - Coil k-space data at non-uniform locations, - with shape `(..., coils, k2, k1, k0)`. + K-space data at the non-uniform trajectory locations with the same leading + dimensions as the input and trailing sampled dimensions, e.g. + `(..., coils, k2, k1, k0)` for 3D or `(..., coils, k1, k0)` for 2D. """ return super().__call__(x) @@ -282,10 +287,35 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: x = x.permute(*unpermute_zyx) return (x,) + def toeplitz( + self, + weight: torch.Tensor | DcfData | None = None, + subspace: torch.Tensor | PCACompressionOp | None = None, + subspace_dim: int = 0, + ) -> LinearOperator: + """Return the Toeplitz Gram operator. + + Parameters + ---------- + weight + Optional density compensation or k-space weights. If provided, calculates + the normal operator corresponding to :math:`F^H W F`. + subspace + Optional temporal subspace basis of shape `(n_timepoints, n_coefficients)`, or + a `PCACompressionOp` whose adjoint defines the temporal expansion basis. + subspace_dim + Dimension of the coefficient channel in the input/output tensors when `subspace` is provided. + """ + if subspace is None: + return NonUniformFastFourierOpGramOp(self, weight=weight) + return SubspaceNonUniformFastFourierOpGramOp( + self, subspace_basis=subspace, weight=weight, subspace_dim=subspace_dim + ) + @property def gram(self) -> LinearOperator: - """Return the gram operator.""" - return NonUniformFastFourierOpGramOp(self) + """Return the unweighted Toeplitz Gram operator.""" + return self.toeplitz() def __repr__(self) -> str: """Representation method for NUFFT operator.""" @@ -356,17 +386,17 @@ def gram_nufft_kernel( else: # second half in the dimension idx.append(slice(kernel_part.size(dim) + 1, None)) kernel_part = kernel_part.index_select( - dim, torch.arange(kernel_part.size(dim) - 1, 0, -1, device=kernel.device) + dim, + torch.arange(kernel_part.size(dim) - 1, 0, -1, device=kernel.device), ) # flip kernel[tuple(idx)] = kernel_part kernel = symmetrize(kernel, rank) kernel = torch.fft.hfftn(kernel, dim=list(range(-rank, 0)), norm='backward') - kernel = torch.fft.fftshift(kernel, dim=list(range(-rank, 0))) return kernel -class NonUniformFastFourierOpGramOp(LinearOperator): +class NonUniformFastFourierOpGramOp(LinearOperator, adjoint_as_backward=True): """Gram operator for `NonUniformFastFourierOp`. Implements the adjoint of the forward operator of the non-uniform Fast Fourier operator, i.e. the gram operator @@ -377,51 +407,91 @@ class NonUniformFastFourierOpGramOp(LinearOperator): This should not be used directly, but rather through the `~NonUniformFastFourierOp.gram` method of a `NonUniformFastFourierOp` object. + + .. note:: + Consider calling .half() on the operator to save memory at the cost of precision. """ _kernel: torch.Tensor | None - def __init__(self, nufft_op: NonUniformFastFourierOp) -> None: + def __init__(self, nufft_op: NonUniformFastFourierOp, weight: torch.Tensor | DcfData | None = None) -> None: """Initialize the gram operator. Parameters ---------- nufft_op The py:class:`NonUniformFastFourierOp` to calculate the gram operator for. - + weight + Optional density compensation weights. If provided, calculates F^H W F. """ super().__init__() - self.nufft_gram: None | LinearOperator = None - - if not nufft_op._dimension_210: + if not nufft_op._direction_zyx: + self._kernel = None return - weight = torch.ones( - [*nufft_op._traj_broadcast_shape[:-4], 1, *nufft_op._traj_broadcast_shape[-3:]], - ).to(nufft_op._omega) + self._dim = nufft_op._direction_zyx + self._recon_shape = nufft_op._im_size - # We rearrange weight into (sep_dims, joint_dims, nufft_dims) + if isinstance(weight, DcfData): + weight = weight.data + weight_shape = [*nufft_op._traj_broadcast_shape[:-4], 1, *nufft_op._traj_broadcast_shape[-3:]] + if weight is None: + weight = torch.ones(weight_shape, device=nufft_op._omega.device, dtype=nufft_op._omega.dtype.to_real()) + else: + weight = weight.to(nufft_op._omega.device).broadcast_to(weight_shape) + # We rearrange weight into (sep_dims, joint_dims, nufft_directions) _, permute_zyx, sep_dims_210, permute_210 = nufft_op._separate_joint_dimensions(weight.ndim) unpermute_zyx = torch.tensor(permute_zyx).argsort().tolist() - weight = weight.permute(*permute_210) unflatten_other_shape = weight.shape[: -len(nufft_op._dimension_210) - 1] # -1 for coil - # combine sep_dims - weight = weight.flatten(end_dim=len(sep_dims_210) - 1) if len(sep_dims_210) else weight[None, :] + if sep_dims_210: # combine sep_dims + weight = weight.flatten(end_dim=len(sep_dims_210) - 1) + else: + weight = weight[None, :] # combine joint_dims and nufft_dims weight = weight.flatten(start_dim=1, end_dim=-len(nufft_op._dimension_210) - 1).flatten(start_dim=2) kernel = gram_nufft_kernel(weight, nufft_op._omega, nufft_op._im_size) kernel = kernel.reshape(*unflatten_other_shape, -1, *kernel.shape[-len(nufft_op._direction_zyx) :]) - kernel = kernel.permute(*unpermute_zyx) - kernel = kernel * (nufft_op.scale) ** 2 + self._kernel = kernel.permute(*unpermute_zyx) * nufft_op.scale.item() ** 2 - fft = FastFourierOp( - dim=nufft_op._direction_zyx, - encoding_matrix=[2 * s for s in nufft_op._im_size], - recon_matrix=nufft_op._im_size, - ) - self.nufft_gram = fft.H * kernel @ fft + @property + def recon_matrix(self) -> SpatialDimension[int]: + """Expected reconstructed image shape as a spatial dimension.""" + zyx = [1, 1, 1] + for direction, size in zip(self._dim, self._recon_shape, strict=True): + zyx[direction] = size + return SpatialDimension(*zyx) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply forward of NonUniformFastFourierOpGramOp. + + .. note:: + Prefer calling the instance of the NonUniformFastFourierOpGramOp operator as ``operator(x)`` over + directly calling this method. See this PyTorch `discussion `_. + """ + # We do the fft and cropping here directly, as it faster and more memory efficient + # then using the operators. As this will be a bottleneck in iterative 3D reconstructions, + # it is worth the specialization. + + # This function on its own is also not autograd save (in-place ops), but we rely on the + # adjoint-as-backward trick for linear operators to make it work + + if self._kernel is None: + return (x,) + + padded_shape = [2 * s for s in self._recon_shape] + spatial_crop: list[slice | EllipsisType] = [..., slice(None), slice(None), slice(None)] + for d, s in zip(self._dim, self._recon_shape, strict=True): + spatial_crop[d] = slice(0, s) + + x = torch.fft.fftn(x, s=padded_shape, dim=self._dim) + x.mul_(self._kernel) + x = torch.fft.ifftn(x, dim=self._dim, out=x) + x = x[tuple(spatial_crop)] + out = x.clone() # clone to deallocate the larger x on exit + + return (out,) def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]: """Apply the Gram operator of the NonUniformFastFourierOp (NUFFT.H @ NUFFT). @@ -437,36 +507,210 @@ def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]: Returns ------- - Output tensor, image-space data after NUFFT.H @ NUFFT has been applied. + Image-space data after applying `NUFFT.H @ NUFFT`, with the same shape as `x`. """ return super().__call__(x) + def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the adjoint of the Gram operator. + + Since the Gram operator (NUFFT.H @ NUFFT) is self-adjoint, + this method calls the forward operation. + + Parameters + ---------- + x + Input tensor, typically image-space data with shape `(..., coils, z, y, x)`. + + Returns + ------- + Image-space data after applying the adjoint Gram operator, with the same shape as `x`. + """ + return super().__call__(x) + + @property + def H(self) -> Self: # noqa: N802 + """Adjoint operator of the gram operator.""" + return self + + +class SubspaceNonUniformFastFourierOpGramOp(LinearOperator, adjoint_as_backward=True): + """Subspace Gram operator for `NonUniformFastFourierOp`. + + This operator acts on temporally compressed coefficient images and applies the + compressed normal operator corresponding to a time-varying non-uniform Fourier + encoding. The Toeplitz kernel is precomputed once during initialization. + + The current implementation is intentionally restricted to the common MRF case: + exactly one varying separate trajectory dimension, interpreted as time, and no + additional non-singleton joint trajectory dimensions. + """ + + _kernel: torch.Tensor | None + + def __init__( + self, + nufft_op: NonUniformFastFourierOp, + subspace_basis: torch.Tensor | PCACompressionOp, + weight: torch.Tensor | DcfData | None = None, + subspace_dim: int = 0, + ) -> None: + """Initialize the subspace Gram operator. + + Parameters + ---------- + nufft_op + The non-uniform Fourier operator defining the trajectory and spatial geometry. + subspace_basis + Temporal subspace basis of shape `(n_timepoints, n_coefficients)`, or + a `PCACompressionOp` whose adjoint defines the temporal expansion basis. + weight + Optional density compensation or k-space weights. If provided, calculates + the compressed normal operator corresponding to :math:`F^H W F`. + subspace_dim + Dimension of the coefficient channel in the input/output tensors. + """ + super().__init__() + if not nufft_op._direction_zyx: + self._kernel = None + return + + self._dim = nufft_op._direction_zyx + self._recon_shape = nufft_op._im_size + self._subspace_dim = subspace_dim + + if isinstance(subspace_basis, PCACompressionOp): + basis = subspace_basis.compression_matrix.mH + else: + basis = subspace_basis + if basis.ndim > 2: + basis = basis.squeeze(tuple(range(basis.ndim - 2))) + if basis.ndim > 2: + raise ValueError(f'Basis cannot contain non-singleton batch dimensions; got shape {basis.shape}.') + if basis.ndim == 1: # rank-1 special case: only one coefficient + basis = basis.unsqueeze(-1) + basis = basis.to(device=nufft_op._omega.device) + + if isinstance(weight, DcfData): + weight = weight.data + weight_shape = [*nufft_op._traj_broadcast_shape[:-4], 1, *nufft_op._traj_broadcast_shape[-3:]] + if weight is None: + weight = torch.ones(weight_shape, device=nufft_op._omega.device, dtype=nufft_op._omega.dtype.to_real()) + else: + weight = weight.to(nufft_op._omega.device).broadcast_to(weight_shape) + _, _permute_zyx, sep_dims_210, permute_210 = nufft_op._separate_joint_dimensions(weight.ndim) + weight = weight.permute(*permute_210) + if sep_dims_210: + weight = weight.flatten(end_dim=len(sep_dims_210) - 1) + else: + weight = weight[None, :] + weight = weight.flatten(start_dim=1, end_dim=-len(nufft_op._dimension_210) - 1).flatten(start_dim=2) + + omega = nufft_op._omega + + if len(sep_dims_210) > 1: + raise NotImplementedError( + 'SubspaceNonUniformFastFourierOpGramOp currently only supports a single varying separate dimension.' + ) + if basis.shape[0] != (n_timepoints := weight.shape[0]): + raise ValueError( + f'subspace_basis has {basis.shape[0]} time points, but the trajectory varies over {n_timepoints}.' + ) + + subspace_kernel: torch.Tensor | None = None + for basis_row, weight_row, omega_row in zip(basis, weight, omega, strict=True): + current = gram_nufft_kernel(weight_row[None], omega_row[None], nufft_op._im_size)[0, 0] + coeff_outer = basis_row.conj()[:, None] * basis_row[None, :] + term = unsqueeze_right(coeff_outer, current.ndim) * current + subspace_kernel = term if subspace_kernel is None else subspace_kernel + term + + assert subspace_kernel is not None # noqa: S101 + kernel_shape = [1, 1, 1] + for direction, size in zip(self._dim, subspace_kernel.shape[2:], strict=True): + kernel_shape[direction] = size + subspace_kernel = subspace_kernel.reshape(*subspace_kernel.shape[:2], *kernel_shape) + self._kernel = subspace_kernel * nufft_op.scale.item() ** 2 + + @property + def n_coefficients(self) -> int: + """Number of compressed coefficient channels.""" + if self._kernel is None: + raise RuntimeError('n_coefficients is undefined if there are no NUFFT axes.') + return self._kernel.shape[0] + + @property + def recon_matrix(self) -> SpatialDimension[int]: + """Expected reconstructed image shape as a spatial dimension.""" + zyx = [1, 1, 1] + for direction, size in zip(self._dim, self._recon_shape, strict=True): + zyx[direction] = size + return SpatialDimension(*zyx) + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]: - """Apply forward of NonUniformFastFourierOpGramOp. + """Apply forward of SubspaceNonUniformFastFourierOpGramOp. .. note:: - Prefer calling the instance of the NonUniformFastFourierOpGramOp operator as ``operator(x)`` over + Prefer calling the instance of the SubspaceNonUniformFastFourierOpGramOp operator as ``operator(x)`` over directly calling this method. See this PyTorch `discussion `_. """ - if self.nufft_gram is not None: - (x,) = self.nufft_gram(x) + # We do the fft, coefficient mixing, and cropping here directly, as it is faster and more memory efficient + # than using separate operators. + + # This function on its own is also not autograd save (in-place ops), but we rely on the + # adjoint-as-backward trick for linear operators to make it work. + if self._kernel is None: + return (x,) + subspace_dim = normalize_index(x.ndim, self._subspace_dim) + if x.shape[subspace_dim] != self.n_coefficients: + raise ValueError( + f'Input has {x.shape[self._subspace_dim]} coefficients along subspace_dim={self._subspace_dim}, ' + f'expected {self.n_coefficients}.' + ) + x = x.movedim(subspace_dim, 0) - return (x,) + padded_shape = [2 * s for s in self._recon_shape] + spatial_crop: list[slice | EllipsisType] = [..., slice(None), slice(None), slice(None)] + for d, s in zip(self._dim, self._recon_shape, strict=True): + spatial_crop[d] = slice(0, s) + + x = torch.fft.fftn(x, s=padded_shape, dim=self._dim) + x = einops.einsum(self._kernel.to(x.dtype), x, 'coeff_out coeff_in ..., coeff_in ... -> coeff_out ...') + x = torch.fft.ifftn(x, dim=self._dim) + x = x[tuple(spatial_crop)] + out = x.clone().movedim(0, subspace_dim) + return (out,) + + def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply the compressed Toeplitz Gram operator. + + Parameters + ---------- + x + Subspace coefficient images. The coefficient axis is given by `subspace_dim` + + Returns + ------- + Subspace coefficient images after applying the compressed Gram operator, + with the same shape as `x`. + """ + return super().__call__(x) def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]: - """Apply the adjoint of the Gram operator. + """Apply the adjoint of the compressed Toeplitz Gram operator. - Since the Gram operator (NUFFT.H @ NUFFT) is self-adjoint, - this method calls the forward operation. + Since the compressed Gram operator is self-adjoint, this method calls the + forward operation. Parameters ---------- x - Input tensor, typically image-space data with shape `(..., coils, z, y, x)`. + Subspace coefficient images. The coefficient axis is given by `subspace_dim` + Returns ------- - Output tensor, same shape as the input. + Subspace coefficient images after applying the adjoint compressed Gram operator, + with the same shape as `x`. """ return super().__call__(x) diff --git a/src/mrpro/operators/PCACompressionOp.py b/src/mrpro/operators/PCACompressionOp.py index c9ad87363..556a25164 100644 --- a/src/mrpro/operators/PCACompressionOp.py +++ b/src/mrpro/operators/PCACompressionOp.py @@ -108,3 +108,8 @@ def adjoint(self, data: torch.Tensor) -> tuple[torch.Tensor,]: f'cannot be multiplied with Data {tuple(data.shape)}.' ) from e return (result,) + + @property + def compression_matrix(self) -> torch.Tensor: + """Get the compression matrix.""" + return self._compression_matrix diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index b8e64333b..6bf2addfa 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -33,7 +33,9 @@ from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix from mrpro.operators.MagnitudeOp import MagnitudeOp from mrpro.operators.MultiIdentityOp import MultiIdentityOp -from mrpro.operators.NonUniformFastFourierOp import NonUniformFastFourierOp +from mrpro.operators.NonUniformFastFourierOp import ( + NonUniformFastFourierOp, +) from mrpro.operators.OptimizerOp import OptimizerOp from mrpro.operators.PatchOp import PatchOp from mrpro.operators.PCACompressionOp import PCACompressionOp diff --git a/tests/operators/test_non_uniform_fast_fourier_op.py b/tests/operators/test_non_uniform_fast_fourier_op.py index 5a29bb6f0..a06fe16d8 100644 --- a/tests/operators/test_non_uniform_fast_fourier_op.py +++ b/tests/operators/test_non_uniform_fast_fourier_op.py @@ -1,12 +1,14 @@ """Tests for Non-Uniform Fast Fourier operator.""" +import einops import pytest import torch -from mrpro.data import KData, KTrajectory -from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.data import DcfData, KData, KTrajectory, SpatialDimension from mrpro.data.traj_calculators import KTrajectoryIsmrmrd -from mrpro.operators import FastFourierOp, NonUniformFastFourierOp +from mrpro.operators import DensityCompensationOp, FastFourierOp, NonUniformFastFourierOp, PCACompressionOp +from mrpro.operators.NonUniformFastFourierOp import SubspaceNonUniformFastFourierOpGramOp from mrpro.utils import RandomGenerator +from torch.autograd.gradcheck import gradcheck from tests.conftest import COMMON_MR_TRAJECTORIES, create_traj from tests.helper import dotproduct_adjointness_test, relative_image_difference @@ -20,6 +22,33 @@ def create_data(img_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) -> tuple[to return img, trajectory +def create_time_varying_2d_nufft_op( + n_timepoints: int = 4, + image_shape: tuple[int, int] = (12, 10), +) -> NonUniformFastFourierOp: + """Create a small time-varying 2D NUFFT over y/x.""" + nk = (n_timepoints, 1, 1, *image_shape) + trajectory = create_traj( + nkx=nk, + nky=nk, + nkz=(n_timepoints, 1, 1, 1, 1), + type_kx='non-uniform', + type_ky='non-uniform', + type_kz='zero', + ) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) + return NonUniformFastFourierOp( + direction=(-2, -1), + recon_matrix=SpatialDimension(z=1, y=image_shape[0], x=image_shape[1]), + encoding_matrix=encoding_matrix, + traj=trajectory, + ) + + @COMMON_MR_TRAJECTORIES def test_non_uniform_fast_fourier_op_fwd_adj_property( img_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2 @@ -239,6 +268,195 @@ def test_non_uniform_fast_fourier_op_repr(): assert 'device' in repr_str +def test_subspace_non_uniform_fast_fourier_op_gram() -> None: + """Test subspace Toeplitz Gram against explicit expand-apply-compress reference.""" + rng = RandomGenerator(seed=1) + n_timepoints, n_coefficients = 4, 2 + image_shape = (12, 10) + + nufft_op = create_time_varying_2d_nufft_op(n_timepoints=n_timepoints, image_shape=image_shape) + basis = rng.complex64_tensor((n_timepoints, n_coefficients)) + alpha = rng.complex64_tensor((n_coefficients, 1, 1, *image_shape)) + + subspace_gram = nufft_op.toeplitz(subspace=basis) + assert isinstance(subspace_gram, SubspaceNonUniformFastFourierOpGramOp) + + expanded = einops.einsum(basis, alpha, 'time coeff, coeff joint coil ... -> time joint coil ...') + (kspace,) = nufft_op(expanded) + (backprojected,) = nufft_op.H(kspace) + expected = einops.einsum(basis.conj(), backprojected, 'time coeff, time joint coil ... -> coeff joint coil ...') + (actual,) = subspace_gram(alpha) + + torch.testing.assert_close(actual, expected, rtol=2e-3, atol=2e-3) + + +def test_subspace_non_uniform_fast_fourier_op_gram_non_contiguous_directions() -> None: + """Test subspace Toeplitz Gram over non-contiguous z/x image axes.""" + rng = RandomGenerator(seed=7) + n_coefficients = 2 + img_shape = (8, 5, 64, 1, 64) + nkx = (8, 1, 1, 18, 128) + nky = (8, 1, 1, 1, 1) + nkz = (8, 1, 1, 18, 128) + trajectory = create_traj(nkx, nky, nkz, type_kx='non-uniform', type_ky='zero', type_kz='non-uniform') + + recon_matrix = SpatialDimension(img_shape[-3], img_shape[-2], img_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) + direction = [d for d, e in zip(('z', 'y', 'x'), encoding_matrix.zyx, strict=False) if e > 1] + nufft_op = NonUniformFastFourierOp( + direction=direction, # type: ignore[arg-type] + recon_matrix=recon_matrix, + encoding_matrix=encoding_matrix, + traj=trajectory, + ) + basis = rng.complex64_tensor((img_shape[0], n_coefficients)) + alpha = rng.complex64_tensor((n_coefficients, *img_shape[1:])) + + subspace_gram = nufft_op.toeplitz(subspace=basis) + + expanded = einops.einsum(basis, alpha, 'time coeff, coeff coil ... -> time coil ...') + (kspace,) = nufft_op(expanded) + (backprojected,) = nufft_op.H(kspace) + expected = einops.einsum(basis.conj(), backprojected, 'time coeff, time coil ... -> coeff coil ...') + (actual,) = subspace_gram(alpha) + + torch.testing.assert_close(actual, expected, rtol=2e-3, atol=2e-3) + + +@pytest.mark.parametrize('n_coefficients', [1, 2]) +def test_subspace_non_uniform_fast_fourier_op_gram_single_timepoint_basis(n_coefficients: int) -> None: + """Test single-timepoint bases keep their time and coefficient axes.""" + rng = RandomGenerator(seed=8) + n_timepoints = 1 + image_shape = (12, 10) + + nufft_op = create_time_varying_2d_nufft_op(n_timepoints=n_timepoints, image_shape=image_shape) + basis = rng.complex64_tensor((n_timepoints, n_coefficients)) + alpha = rng.complex64_tensor((n_coefficients, 1, 1, *image_shape)) + + expanded = einops.einsum(basis, alpha, 'time coeff, coeff joint coil ... -> time joint coil ...') + (kspace,) = nufft_op(expanded) + (backprojected,) = nufft_op.H(kspace) + expected = einops.einsum(basis.conj(), backprojected, 'time coeff, time joint coil ... -> coeff joint coil ...') + (actual,) = nufft_op.toeplitz(subspace=basis)(alpha) + + torch.testing.assert_close(actual, expected, rtol=2e-3, atol=2e-3) + + +def test_subspace_non_uniform_fast_fourier_op_gram_accepts_pca_operator() -> None: + """Test subspace Gram accepts a PCACompressionOp as basis input.""" + rng = RandomGenerator(seed=2) + n_timepoints, n_coefficients = 4, 2 + image_shape = (12, 10) + + nufft_op = create_time_varying_2d_nufft_op(n_timepoints=n_timepoints, image_shape=image_shape) + training_signals = rng.complex64_tensor((1, 32, n_timepoints)) + pca_op = PCACompressionOp(training_signals, n_components=n_coefficients, centering=False) + basis = pca_op.compression_matrix.squeeze(0).mH + alpha = rng.complex64_tensor((n_coefficients, 1, 1, *image_shape)) + + subspace_gram_from_pca = SubspaceNonUniformFastFourierOpGramOp(nufft_op, pca_op) + subspace_gram_from_basis = SubspaceNonUniformFastFourierOpGramOp(nufft_op, basis) + + (actual_from_pca,) = subspace_gram_from_pca(alpha) + (actual_from_basis,) = subspace_gram_from_basis(alpha) + + torch.testing.assert_close(actual_from_pca, actual_from_basis) + + +def test_non_uniform_fast_fourier_op_weighted_toeplitz_matches_explicit_dcf_normal_operator() -> None: + """Test Toeplitz(weight=dcf) matches the explicit weighted normal operator F^H DCF F.""" + rng = RandomGenerator(seed=6) + n_timepoints = 4 + image_shape = (12, 10) + + nufft_op = create_time_varying_2d_nufft_op(n_timepoints=n_timepoints, image_shape=image_shape) + image = rng.complex64_tensor((n_timepoints, 1, 1, *image_shape)) + dcf = DcfData(data=rng.float32_tensor((n_timepoints, 1, 1, 1, image_shape[-1]))) + + explicit_operator = nufft_op.H @ DensityCompensationOp(dcf) @ nufft_op + weighted_toeplitz = nufft_op.toeplitz(weight=dcf) + + (expected,) = explicit_operator(image) + (actual,) = weighted_toeplitz(image) + + torch.testing.assert_close(actual, expected, rtol=2e-3, atol=2e-3) + + +def test_non_uniform_fast_fourier_op_gram_autograd() -> None: + """Test autograd of the Toeplitz Gram operator.""" + rng = RandomGenerator(seed=3) + n_timepoints = 2 + image_shape = (5, 4) + nufft_op = create_time_varying_2d_nufft_op( + n_timepoints=n_timepoints, + image_shape=image_shape, + ) + operator = nufft_op.toeplitz().double() + image = rng.complex128_tensor((n_timepoints, 1, 1, *image_shape)).requires_grad_(True) + + gradcheck(operator, (image,), fast_mode=True) + + +def test_subspace_non_uniform_fast_fourier_op_gram_autograd() -> None: + """Test autograd of the subspace Toeplitz Gram operator.""" + rng = RandomGenerator(seed=4) + n_timepoints, n_coefficients = 3, 2 + image_shape = (5, 4) + nufft_op = create_time_varying_2d_nufft_op( + n_timepoints=n_timepoints, + image_shape=image_shape, + ) + basis = rng.complex128_tensor((n_timepoints, n_coefficients)) + operator = nufft_op.toeplitz(subspace=basis).double() + coefficients = rng.complex128_tensor((n_coefficients, 1, 1, *image_shape)).requires_grad_(True) + + gradcheck(operator, (coefficients,), fast_mode=True) + + +@pytest.mark.cuda +def test_non_uniform_fast_fourier_op_gram_cuda() -> None: + """Test Toeplitz Gram operators work on CUDA devices.""" + rng = RandomGenerator(seed=5) + n_timepoints = 4 + image_shape = (12, 10) + image = rng.complex64_tensor((n_timepoints, 1, 1, *image_shape)).cuda() + + nufft_op = create_time_varying_2d_nufft_op(n_timepoints=n_timepoints, image_shape=image_shape).cuda() + gram = nufft_op.gram.cuda() + (result,) = gram(image) + assert result.is_cuda + + nufft_op = create_time_varying_2d_nufft_op(n_timepoints=n_timepoints, image_shape=image_shape) + gram = nufft_op.gram + (result,) = gram(image) + assert result.is_cuda + + +@pytest.mark.cuda +def test_non_uniform_fast_fourier_op_subspace_gram_cuda() -> None: + """Test subspace Toeplitz Gram operators work on CUDA devices.""" + rng = RandomGenerator(seed=5) + n_timepoints, n_coefficients = 4, 2 + image_shape = (12, 10) + coefficients = rng.complex64_tensor((n_coefficients, 1, 1, *image_shape)).cuda() + basis = rng.complex64_tensor((n_timepoints, n_coefficients)) + + nufft_op = create_time_varying_2d_nufft_op(n_timepoints=n_timepoints, image_shape=image_shape) + subspace_gram = nufft_op.toeplitz(subspace=basis).cuda() + (result,) = subspace_gram(coefficients) + assert result.is_cuda + + nufft_op = create_time_varying_2d_nufft_op(n_timepoints=n_timepoints, image_shape=image_shape).cuda() + subspace_gram = nufft_op.toeplitz(subspace=basis.cuda()) + (result,) = subspace_gram(coefficients) + assert result.is_cuda + + @pytest.mark.cuda def test_non_uniform_fast_fourier_op_cuda() -> None: """Test non-uniform fast Fourier operator works on CUDA devices."""