Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/mrpro/operators/CirculantPreconditioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""A preconditioner for non-Cartesian iterative SENSE reconstruction."""

import torch

from mrpro.data.DcfData import DcfData
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.operators.NonUniformFastFourierOp import NonUniformFastFourierOp


class CirculantPreconditioner(LinearOperator):
"""A preconditioner for a non-Cartesian SENSE reconstruction."""

def __init__(self, nufft_operator: NonUniformFastFourierOp, dcf: DcfData):
"""Initialize a circulant preconditioner for a non-Cartesian SENSE reconstruction.

This operator acts as a preconditioner P for iterative algorithms
solving Ax=b, where A involves NUFFT operations (F) andoptionally
coil sensitivities (C), e.g., A = C^H F^H F C.

The preconditioner approximates the inverse of the density-compensated
operator, P ≈ (F^H W F)^(-1), where W represents the density
compensation factors (DCF). It is constructed by:
1. Simulating the density-compensated Point Spread Function (PSF)
h_w = F^H W F(delta).
2. Computing the FFT of the PSF. These are the
eigenvalues of the circulant approximation.
3. Regularizing and inverting these eigenvalues to get the k-space
kernel

This preconditioner is suitable for accelerating solvers for the
*unweighted* least-squares problem (where A does not include W),
as it compensates for density variations internally.

Parameters
----------
nufft_operator
The non-uniform fast fourier transform operator.
dcf
density compensation weights
"""
super().__init__()
device = dcf.device if dcf.device is not None else nufft_operator._omega.device
im_shape_zyx = [1, 1, 1]
for dim, size in zip(nufft_operator._direction_zyx, nufft_operator._im_size, strict=True):
im_shape_zyx[dim + 3] = size

delta_image = torch.zeros((1, *im_shape_zyx), dtype=torch.complex64, device=device)
center_indices = tuple(size // 2 for size in im_shape_zyx)
delta_image[(0, *center_indices)] = 1.0
(k,) = nufft_operator(delta_image)
k = k * dcf.data
(psf,) = nufft_operator.adjoint(k)
kernel = torch.fft.fftn(torch.fft.ifftshift(psf, dim=(-1, -2, -3)), dim=(-1, -2, -3))
kernel = torch.polar(kernel.abs().clamp(min=1e-5), kernel.angle()).reciprocal()
self.kernel = kernel

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the inverse of the preconditioner."""
x = torch.fft.fftn(x, dim=(-1, -2, -3))
x = x * self.kernel
x = torch.fft.ifftn(x, dim=(-1, -2, -3))
return (x,)

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply the adjoint of the inverse of the preconditioner."""
x = torch.fft.fftn(x, dim=(-1, -2, -3))
x = x * self.kernel.conj()
x = torch.fft.ifftn(x, dim=(-1, -2, -3))
return (x,)
2 changes: 2 additions & 0 deletions src/mrpro/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mrpro.operators import functionals, models
from mrpro.operators.AveragingOp import AveragingOp
from mrpro.operators.CartesianSamplingOp import CartesianSamplingOp, CartesianMaskingOp
from mrpro.operators.CirculantPreconditioner import CirculantPreconditioner
from mrpro.operators.ConvAnalysisDictionaryOp import ConvAnalysisDictionaryOp
from mrpro.operators.ConvSynthesisDictionaryOp import ConvSynthesisDictionaryOp
from mrpro.operators.ConjugateGradientOp import ConjugateGradientOp
Expand Down Expand Up @@ -56,6 +57,7 @@
"B0InformedFourierOp",
"CartesianMaskingOp",
"CartesianSamplingOp",
"CirculantPreconditioner",
"ConjugateGradientOp",
"ConjugatePhaseFourierOp",
"ConstraintsOp",
Expand Down
144 changes: 144 additions & 0 deletions tests/operators/test_circulant_preconditioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Tests for circulant preconditioner operator."""

import pytest
import torch
from mrpro.algorithms.optimizers import cg
from mrpro.data import DcfData
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.operators import CirculantPreconditioner, NonUniformFastFourierOp
from mrpro.utils import RandomGenerator

from tests import dotproduct_adjointness_test
from tests.conftest import create_traj


def create_circulant_preconditioner_and_domain_range() -> tuple[
CirculantPreconditioner,
NonUniformFastFourierOp,
DcfData,
torch.Tensor,
torch.Tensor,
]:
"""Create circulant preconditioner and random elements from domain and range."""
rng = RandomGenerator(seed=0)

img_shape = (1, 1, 1, 24, 24)
nkx = (1, 1, 1, 12, 24)
nky = (1, 1, 1, 12, 24)
nkz = (1, 1, 1, 1, 1)
traj = create_traj(nkx, nky, nkz, 'non-uniform', 'non-uniform', 'zero')

recon_matrix = SpatialDimension(img_shape[-3], img_shape[-2], img_shape[-1])
encoding_matrix = SpatialDimension(
int(traj.kz.max() - traj.kz.min() + 1),
int(traj.ky.max() - traj.ky.min() + 1),
int(traj.kx.max() - traj.kx.min() + 1),
)
direction = [d for d, e in zip(('z', 'y', 'x'), encoding_matrix.zyx, strict=False) if e > 1]
nufft_op = NonUniformFastFourierOp(
direction=direction, # type: ignore[arg-type]
recon_matrix=recon_matrix,
encoding_matrix=encoding_matrix,
traj=traj,
)

dcf = DcfData.from_traj_voronoi(traj)
circulant_preconditioner = CirculantPreconditioner(nufft_op, dcf)

u = rng.complex64_tensor(size=img_shape)
v = rng.complex64_tensor(size=img_shape)

return circulant_preconditioner, nufft_op, dcf, u, v


def test_circulant_preconditioner_adjointness() -> None:
"""Test adjoint property of circulant preconditioner."""
circulant_preconditioner, _, _, u, v = create_circulant_preconditioner_and_domain_range()
dotproduct_adjointness_test(circulant_preconditioner, u, v)


def test_circulant_preconditioner_cg_iteration() -> None:
"""Test circulant preconditioner in CG iterations."""
circulant_preconditioner, nufft_op, dcf, _, _ = create_circulant_preconditioner_and_domain_range()

operator = nufft_op.H @ dcf.as_operator() @ nufft_op
rng = RandomGenerator(seed=1)
true_solution = rng.complex64_tensor(size=(1, 1, 1, 24, 24))
(right_hand_side,) = operator(true_solution)

initial_value = torch.zeros_like(true_solution)
initial_residual = torch.linalg.vector_norm(right_hand_side.flatten())

(solution_without_preconditioner,) = cg(
operator,
right_hand_side,
initial_value=initial_value,
max_iterations=3,
tolerance=0,
)
residual_without_preconditioner = torch.linalg.vector_norm(
(operator(solution_without_preconditioner)[0] - right_hand_side).flatten()
)

(solution_with_preconditioner,) = cg(
operator,
right_hand_side,
initial_value=initial_value,
preconditioner_inverse=circulant_preconditioner,
max_iterations=3,
tolerance=0,
)
residual_with_preconditioner = torch.linalg.vector_norm(
(operator(solution_with_preconditioner)[0] - right_hand_side).flatten()
)

assert residual_without_preconditioner < initial_residual
assert residual_with_preconditioner < initial_residual


@pytest.mark.cuda
def test_circulant_preconditioner_cuda() -> None:
"""Test circulant preconditioner works on CUDA devices."""
rng = RandomGenerator(seed=2)
x = rng.complex64_tensor(size=(1, 1, 1, 24, 24))

# Create on CPU, transfer to GPU, run on GPU
preconditioner, _, _, _, _ = create_circulant_preconditioner_and_domain_range()
preconditioner.cuda()
(result,) = preconditioner(x.cuda())
assert result.is_cuda

# Create on CPU, run on CPU
preconditioner, _, _, _, _ = create_circulant_preconditioner_and_domain_range()
(result,) = preconditioner(x)
assert result.is_cpu

# Create on GPU, run on GPU
img_shape = (1, 1, 1, 24, 24)
nkx = (1, 1, 1, 12, 24)
nky = (1, 1, 1, 12, 24)
nkz = (1, 1, 1, 1, 1)
traj = create_traj(nkx, nky, nkz, 'non-uniform', 'non-uniform', 'zero').cuda()

recon_matrix = SpatialDimension(img_shape[-3], img_shape[-2], img_shape[-1])
encoding_matrix = SpatialDimension(
int(traj.kz.max() - traj.kz.min() + 1),
int(traj.ky.max() - traj.ky.min() + 1),
int(traj.kx.max() - traj.kx.min() + 1),
)
direction = [d for d, e in zip(('z', 'y', 'x'), encoding_matrix.zyx, strict=False) if e > 1]
nufft_op = NonUniformFastFourierOp(
direction=direction, # type: ignore[arg-type]
recon_matrix=recon_matrix,
encoding_matrix=encoding_matrix,
traj=traj,
)
dcf = DcfData.from_traj_voronoi(traj)
preconditioner = CirculantPreconditioner(nufft_op, dcf)
(result,) = preconditioner(x.cuda())
assert result.is_cuda

# Create on GPU, transfer to CPU, run on CPU
preconditioner.cpu()
(result,) = preconditioner(x)
assert result.is_cpu
Loading