Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 232 additions & 11 deletions src/mrpro/operators/NonUniformFastFourierOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from types import EllipsisType
from typing import Literal

import einops
import numpy as np
import torch
from pytorch_finufft.functional import finufft_type1, finufft_type2
Expand All @@ -15,6 +16,8 @@
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.operators.PCACompressionOp import PCACompressionOp
from mrpro.utils import normalize_index, unsqueeze_right


class NonUniformFastFourierOp(LinearOperator, adjoint_as_backward=True):
Expand Down Expand Up @@ -200,12 +203,14 @@ def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
Parameters
----------
x
Coil image data, typically with shape `(..., coils, z, y, x)`.
Image-space data with selected spatial dimensions as trailing axes, e.g.
`(..., coils, z, y, x)` for 3D or `(..., coils, y, x)` for 2D.

Returns
-------
Coil k-space data at non-uniform locations,
with shape `(..., coils, k2, k1, k0)`.
K-space data at the non-uniform trajectory locations with the same leading
dimensions as the input and trailing sampled dimensions, e.g.
`(..., coils, k2, k1, k0)` for 3D or `(..., coils, k1, k0)` for 2D.
"""
return super().__call__(x)

Expand Down Expand Up @@ -282,10 +287,35 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
x = x.permute(*unpermute_zyx)
return (x,)

def toeplitz(
self,
weight: torch.Tensor | DcfData | None = None,
subspace: torch.Tensor | PCACompressionOp | None = None,
subspace_dim: int = 0,
) -> LinearOperator:
"""Return the Toeplitz Gram operator.

Parameters
----------
weight
Optional density compensation or k-space weights. If provided, calculates
the normal operator corresponding to :math:`F^H W F`.
subspace
Optional temporal subspace basis of shape `(n_timepoints, n_coefficients)`, or
a `PCACompressionOp` whose adjoint defines the temporal expansion basis.
subspace_dim
Dimension of the coefficient channel in the input/output tensors when `subspace` is provided.
"""
if subspace is None:
return NonUniformFastFourierOpGramOp(self, weight=weight)
return SubspaceNonUniformFastFourierOpGramOp(
self, subspace_basis=subspace, weight=weight, subspace_dim=subspace_dim
)

@property
def gram(self) -> LinearOperator:
"""Return the gram operator."""
return NonUniformFastFourierOpGramOp(self)
"""Return the unweighted Toeplitz Gram operator."""
return self.toeplitz()

def __repr__(self) -> str:
"""Representation method for NUFFT operator."""
Expand Down Expand Up @@ -404,17 +434,14 @@ def __init__(self, nufft_op: NonUniformFastFourierOp, weight: torch.Tensor | Dcf

if isinstance(weight, DcfData):
weight = weight.data

weight_shape = [*nufft_op._traj_broadcast_shape[:-4], 1, *nufft_op._traj_broadcast_shape[-3:]]

if weight is None:
weight = torch.ones(weight_shape, device=nufft_op._omega.device)
weight = torch.ones(weight_shape, device=nufft_op._omega.device, dtype=nufft_op._omega.dtype.to_real())
else:
weight = weight.to(nufft_op._omega.device).broadcast_to(weight_shape)
# We rearrange weight into (sep_dims, joint_dims, nufft_directions)
_, permute_zyx, sep_dims_210, permute_210 = nufft_op._separate_joint_dimensions(weight.ndim)
unpermute_zyx = torch.tensor(permute_zyx).argsort().tolist()

weight = weight.permute(*permute_210)
unflatten_other_shape = weight.shape[: -len(nufft_op._dimension_210) - 1] # -1 for coil
if sep_dims_210: # combine sep_dims
Expand All @@ -428,6 +455,14 @@ def __init__(self, nufft_op: NonUniformFastFourierOp, weight: torch.Tensor | Dcf
kernel = kernel.reshape(*unflatten_other_shape, -1, *kernel.shape[-len(nufft_op._direction_zyx) :])
self._kernel = kernel.permute(*unpermute_zyx) * nufft_op.scale.item() ** 2

@property
def recon_matrix(self) -> SpatialDimension[int]:
"""Expected reconstructed image shape as a spatial dimension."""
zyx = [1, 1, 1]
for direction, size in zip(self._dim, self._recon_shape, strict=True):
zyx[direction] = size
return SpatialDimension(*zyx)

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply forward of NonUniformFastFourierOpGramOp.

Expand Down Expand Up @@ -472,7 +507,7 @@ def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]:

Returns
-------
Output tensor, image-space data after NUFFT.H @ NUFFT has been applied.
Image-space data after applying `NUFFT.H @ NUFFT`, with the same shape as `x`.
"""
return super().__call__(x)

Expand All @@ -489,7 +524,193 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:

Returns
-------
Output tensor, same shape as the input.
Image-space data after applying the adjoint Gram operator, with the same shape as `x`.
"""
return super().__call__(x)

@property
def H(self) -> Self: # noqa: N802
"""Adjoint operator of the gram operator."""
return self


class SubspaceNonUniformFastFourierOpGramOp(LinearOperator, adjoint_as_backward=True):
"""Subspace Gram operator for `NonUniformFastFourierOp`.

This operator acts on temporally compressed coefficient images and applies the
compressed normal operator corresponding to a time-varying non-uniform Fourier
encoding. The Toeplitz kernel is precomputed once during initialization.

The current implementation is intentionally restricted to the common MRF case:
exactly one varying separate trajectory dimension, interpreted as time, and no
additional non-singleton joint trajectory dimensions.
"""

_kernel: torch.Tensor | None

def __init__(
self,
nufft_op: NonUniformFastFourierOp,
subspace_basis: torch.Tensor | PCACompressionOp,
weight: torch.Tensor | DcfData | None = None,
subspace_dim: int = 0,
) -> None:
"""Initialize the subspace Gram operator.

Parameters
----------
nufft_op
The non-uniform Fourier operator defining the trajectory and spatial geometry.
subspace_basis
Temporal subspace basis of shape `(n_timepoints, n_coefficients)`, or
a `PCACompressionOp` whose adjoint defines the temporal expansion basis.
weight
Optional density compensation or k-space weights. If provided, calculates
the compressed normal operator corresponding to :math:`F^H W F`.
subspace_dim
Dimension of the coefficient channel in the input/output tensors. Spatial
dimensions are assumed to remain trailing, matching the usual MR2 image layout.
"""
super().__init__()
self._dim = nufft_op._direction_zyx
self._recon_shape = nufft_op._im_size
self._subspace_dim = subspace_dim

if not nufft_op._dimension_210:
self._kernel = None
return

if isinstance(subspace_basis, PCACompressionOp):
basis = subspace_basis._compression_matrix.mH
else:
basis = subspace_basis
basis = basis.squeeze()
if basis.ndim > 2:
raise ValueError(f'Basis cannot contain non-singleton batch dimensions; got squeezed shape {basis.shape}.')
elif basis.ndim == 1: # rank-1 special case, we squeezed the singleton subspace dimension
basis = basis.unsqueeze(-1)
basis = basis.to(device=nufft_op._omega.device)

if isinstance(weight, DcfData):
weight = weight.data
weight_shape = [*nufft_op._traj_broadcast_shape[:-4], 1, *nufft_op._traj_broadcast_shape[-3:]]
if weight is None:
weight = torch.ones(weight_shape, device=nufft_op._omega.device, dtype=nufft_op._omega.dtype.to_real())
else:
weight = weight.to(nufft_op._omega.device).broadcast_to(weight_shape)
_, _permute_zyx, sep_dims_210, permute_210 = nufft_op._separate_joint_dimensions(weight.ndim)
weight = weight.permute(*permute_210)
if sep_dims_210:
weight = weight.flatten(end_dim=len(sep_dims_210) - 1)
else:
weight = weight[None, :]
weight = weight.flatten(start_dim=1, end_dim=-len(nufft_op._dimension_210) - 1).flatten(start_dim=2)

omega = nufft_op._omega

if len(sep_dims_210) > 1:
raise NotImplementedError(
'SubspaceNonUniformFastFourierOpGramOp currently only supports a single varying separate dimension.'
)
if basis.shape[0] != (n_timepoints := weight.shape[0]):
raise ValueError(
f'subspace_basis has {basis.shape[0]} time points, but the trajectory varies over {n_timepoints}.'
)

subspace_kernel: torch.Tensor | None = None
for basis_row, weight_row, omega_row in zip(basis, weight, omega, strict=True):
current = gram_nufft_kernel(weight_row[None], omega_row[None], nufft_op._im_size)[0, 0]
coeff_outer = basis_row.conj()[:, None] * basis_row[None, :]
term = unsqueeze_right(coeff_outer, current.ndim) * current
subspace_kernel = term if subspace_kernel is None else subspace_kernel + term

assert subspace_kernel is not None # noqa: S101
self._kernel = subspace_kernel * nufft_op.scale.item() ** 2

@property
def n_coefficients(self) -> int:
"""Number of compressed coefficient channels."""
if self._kernel is None:
raise RuntimeError('n_coefficients is undefined before the Toeplitz kernel is initialized.')
return self._kernel.shape[0]

@property
def recon_matrix(self) -> SpatialDimension[int]:
"""Expected reconstructed image shape as a spatial dimension."""
zyx = [1, 1, 1]
for direction, size in zip(self._dim, self._recon_shape, strict=True):
zyx[direction] = size
return SpatialDimension(*zyx)

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply forward of SubspaceNonUniformFastFourierOpGramOp.

.. note::
Prefer calling the instance of the SubspaceNonUniformFastFourierOpGramOp operator as ``operator(x)`` over
directly calling this method. See this PyTorch `discussion <https://discuss.pytorch.org/t/is-model-forward-x-the-same-as-model-call-x/33460/3>`_.
"""
# We do the fft, coefficient mixing, and cropping here directly, as it is faster and more memory efficient
# than using separate operators.

# This function on its own is also not autograd save (in-place ops), but we rely on the
# adjoint-as-backward trick for linear operators to make it work.
if self._kernel is None:
return (x,)
subspace_dim = normalize_index(x.ndim, self._subspace_dim)
if x.shape[subspace_dim] != self.n_coefficients:
raise ValueError(
f'Input has {x.shape[self._subspace_dim]} coefficients along subspace_dim={self._subspace_dim}, '
f'expected {self.n_coefficients}.'
)
x = x.movedim(subspace_dim, 0)

spatial_dims = tuple(range(x.ndim - len(self._dim), x.ndim))
padded_shape = [2 * s for s in self._recon_shape]
spatial_crop: list[slice] = [slice(None)] * x.ndim
for d, s in zip(spatial_dims, self._recon_shape, strict=True):
spatial_crop[d] = slice(0, s)

x = torch.fft.fftn(x, s=padded_shape, dim=spatial_dims)
x = einops.einsum(self._kernel.to(x.dtype), x, 'coeff_out coeff_in ..., coeff_in ... -> coeff_out ...')
x = torch.fft.ifftn(x, dim=spatial_dims)
x = x[tuple(spatial_crop)]
out = x.clone().movedim(0, subspace_dim)
return (out,)

def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the compressed Toeplitz Gram operator.

Parameters
----------
x
Subspace coefficient images. The coefficient axis is given by `subspace_dim`;
the selected spatial dimensions are trailing axes, e.g. `(coeff, z, y, x)`
for 3D or `(coeff, y, x)` for 2D when `subspace_dim=0`.

Returns
-------
Subspace coefficient images after applying the compressed Gram operator,
with the same shape as `x`.
"""
return super().__call__(x)

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the adjoint of the compressed Toeplitz Gram operator.

Since the compressed Gram operator is self-adjoint, this method calls the
forward operation.

Parameters
----------
x
Subspace coefficient images. The coefficient axis is given by `subspace_dim`;
the selected spatial dimensions are trailing axes, e.g. `(coeff, z, y, x)`
for 3D or `(coeff, y, x)` for 2D when `subspace_dim=0`.

Returns
-------
Subspace coefficient images after applying the adjoint compressed Gram operator,
with the same shape as `x`.
"""
return super().__call__(x)

Expand Down
6 changes: 5 additions & 1 deletion src/mrpro/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
from mrpro.operators.LinearOperatorMatrix import LinearOperatorMatrix
from mrpro.operators.MagnitudeOp import MagnitudeOp
from mrpro.operators.MultiIdentityOp import MultiIdentityOp
from mrpro.operators.NonUniformFastFourierOp import NonUniformFastFourierOp
from mrpro.operators.NonUniformFastFourierOp import (
NonUniformFastFourierOp,
SubspaceNonUniformFastFourierOpGramOp,
)
from mrpro.operators.OptimizerOp import OptimizerOp
from mrpro.operators.PatchOp import PatchOp
from mrpro.operators.PCACompressionOp import PCACompressionOp
Expand Down Expand Up @@ -91,6 +94,7 @@
"SensitivityOp",
"SignalModel",
"SliceProjectionOp",
"SubspaceNonUniformFastFourierOpGramOp",
"TimeSegmentedFourierOp",
"WaveletOp",
"ZeroOp",
Expand Down
Loading
Loading