diff --git a/src/mrpro/nn/__init__.py b/src/mrpro/nn/__init__.py index d6541f5c8..f988855e7 100644 --- a/src/mrpro/nn/__init__.py +++ b/src/mrpro/nn/__init__.py @@ -12,6 +12,7 @@ from mrpro.nn.RMSNorm import RMSNorm from mrpro.nn.Residual import Residual from mrpro.nn.Sequential import Sequential +from mrpro.nn import data_consistency from mrpro.nn.ndmodules import ( adaptiveAvgPoolND, avgPoolND, @@ -40,6 +41,7 @@ 'batchNormND', 'convND', 'convTransposeND', + 'data_consistency', 'instanceNormND', 'maxPoolND', ] diff --git a/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py new file mode 100644 index 000000000..5b2848fca --- /dev/null +++ b/src/mrpro/nn/data_consistency/AnalyticCartesianDC.py @@ -0,0 +1,101 @@ +"""Analytic Cartesian data consistency.""" + +from typing import overload + +import torch +from torch.nn import Module, Parameter + +from mrpro.data.KData import KData +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.IdentityOp import IdentityOp + + +class AnalyticCartesianDC(Module): + r"""Analytic Cartesian data consistency. + + Solves the following problem: + :math:`\min_x \|Ax - k\|_2^2 + \lambda \|x-p\|_2^2` + where :math:`A` is the acquisition operator and :math:`k` is the data, :math:`\lambda` is the regularization + parameter and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a + special case for a Cartesian acquisition without coil sensitivity weighting. This can be used for either single-coil + data or to apply data consistency to each coil image [NOSENSE]_. + + References + ---------- + .. [NOSENSE] Zimmermann, FF, and Kofler, Andreas. "NoSENSE: Learned unrolled cardiac MRI reconstruction without + explicit sensitivity maps." STACOM@MICCAI 2023. https://arxiv.org/abs/2309.15608 + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + + + """ + + def __init__(self, initial_regularization_weight: torch.Tensor | float): + """Initialize the data consistency. + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + Must be a positive scalar. + """ + super().__init__() + weight = torch.as_tensor(initial_regularization_weight) + if weight.ndim != 0: + raise ValueError('Regularization weight must be a scalar') + if weight.item() <= 0: + raise ValueError('Regularization weight must be positive') + self.log_weight = Parameter(weight.log()) + + @overload + def __call__(self, image: torch.Tensor, data: KData, fourier_op: FourierOp | None = None) -> torch.Tensor: ... + + @overload + def __call__(self, image: torch.Tensor, data: torch.Tensor, fourier_op: FourierOp) -> torch.Tensor: ... + + def __call__( + self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate, i.e. the regularized image. + data + k-space data. + fourier_op + Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, + the Fourier operator is automatically created from the data. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, fourier_op) + + def forward( + self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None + ) -> torch.Tensor: + """Apply the data consistency.""" + if fourier_op is None: + if isinstance(data, KData): + fourier_op = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or a FourierOp is required') + + if not isinstance(fourier_op, FourierOp) or fourier_op._nufft_dims or fourier_op._fast_fourier_op is None: + raise ValueError('Only Cartesian acquisitions are supported') + + data_ = data.data if isinstance(data, KData) else data + fft_op = fourier_op._fast_fourier_op + sampling_op = fourier_op._cart_sampling_op if fourier_op._cart_sampling_op is not None else IdentityOp() + (zero_filled,) = sampling_op.adjoint(data_) + (k_pred,) = fft_op(image) + regularization_weight = self.log_weight.exp() + (k,) = sampling_op.gram((zero_filled - k_pred) / (1 + regularization_weight)) + (delta,) = fft_op.H(k) + return image + delta diff --git a/src/mrpro/nn/data_consistency/ConjugateGradientDC.py b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py new file mode 100644 index 000000000..f81b24f98 --- /dev/null +++ b/src/mrpro/nn/data_consistency/ConjugateGradientDC.py @@ -0,0 +1,117 @@ +"""Conjugate gradient data consistency.""" + +from typing import overload + +import torch +from torch.nn import Module, Parameter + +from mrpro.data.CsmData import CsmData +from mrpro.data.KData import KData +from mrpro.operators.ConjugateGradientOp import ConjugateGradientOp +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.LinearOperator import LinearOperator +from mrpro.operators.SensitivityOp import SensitivityOp + + +class ConjugateGradientDC(Module): + """Conjugate gradient data consistency.""" + + def __init__(self, initial_regularization_weight: torch.Tensor | float): + """Initialize the conjugate gradient data consistency. + + Parameters + ---------- + initial_regularization_weight + Initial regularization weight. The regularization weight is a trainable parameter. + Must be a positive scalar. + """ + super().__init__() + weight = torch.as_tensor(initial_regularization_weight) + if weight.ndim != 0: + raise ValueError('Regularization weight must be a scalar') + if weight.item() <= 0: + raise ValueError('Regularization weight must be positive') + self.log_weight = Parameter(weight.log()) + + @overload + def __call__( + self, + image: torch.Tensor, + data: KData, + fourier_op: FourierOp | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: ... + + @overload + def __call__( + self, + image: torch.Tensor, + data: torch.Tensor, + fourier_op: FourierOp, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: ... + + def __call__( + self, + image: torch.Tensor, + data: KData | torch.Tensor, + fourier_op: LinearOperator | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate. + data + k-space data. + fourier_op + Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object, + the Fourier operator is automatically created from the data. + This operator can already include the coil sensitivity weighting, if gradients wrt the coil sensitivity maps + NOT required. Otherwise, they should be given as an additional argument. + csm + Coil sensitivity maps. If None, no coil sensitivity weighting is applied. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, fourier_op, csm) + + def forward( + self, + image: torch.Tensor, + data: torch.Tensor | KData, + fourier_op: FourierOp | None = None, + csm: torch.Tensor | CsmData | None = None, + ) -> torch.Tensor: + """Apply the data consistency.""" + if fourier_op is None: + if isinstance(data, KData): + fourier_op = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or a FourierOp is required') + + data_ = data.data if isinstance(data, KData) else data + + if csm is None: + csm = torch.tensor(()) + elif isinstance(csm, CsmData): + csm = csm.data + + def operator_factory(csm: torch.Tensor, weight: torch.Tensor, *_): + op = fourier_op.gram + if csm.numel(): + csm_op = SensitivityOp(csm) + op = csm_op.H @ op @ csm_op + op = op + weight + return op + + def rhs_factory(_csm: torch.Tensor, weight: torch.Tensor, zero_filled: torch.Tensor, image: torch.Tensor): + return (zero_filled + weight * image,) + + cg_op = ConjugateGradientOp(operator_factory=operator_factory, rhs_factory=rhs_factory) + (result,) = cg_op(csm, self.log_weight.exp(), fourier_op.adjoint(data_)[0], image) + return result diff --git a/src/mrpro/nn/data_consistency/GradientDescentDC.py b/src/mrpro/nn/data_consistency/GradientDescentDC.py new file mode 100644 index 000000000..579e5181a --- /dev/null +++ b/src/mrpro/nn/data_consistency/GradientDescentDC.py @@ -0,0 +1,99 @@ +"""Gradient descent data consistency.""" + +from typing import overload + +import torch +from torch.nn import Module, Parameter + +from mrpro.data.KData import KData +from mrpro.operators.FourierOp import FourierOp +from mrpro.operators.LinearOperator import LinearOperator + + +class GradientDescentDC(Module): + r"""Gradient descent data consistency. + + Performs gradient descent steps on + :math:`\|Ax - k\|_2^2` where :math:`A` is the acquisition operator and :math:`k` is the data. + + Parameters + ---------- + initial_stepsize + Initial stepsize. The stepsize is a trainable parameter. + Must be a positive scalar. + n_steps + Number of gradient descent steps. + + Returns + ------- + The updated image. + """ + + def __init__(self, initial_stepsize: float | torch.Tensor, n_steps: int = 1) -> None: + """Initialize the gradient descent data consistency. + + Parameters + ---------- + initial_stepsize + Initial stepsize. The stepsize is a trainable parameter. + Must be a positive scalar. + n_steps + Number of gradient descent steps. + """ + super().__init__() + stepsize = torch.as_tensor(initial_stepsize) + if stepsize.ndim != 0: + raise ValueError('Stepsize must be a scalar') + if stepsize.item() <= 0: + raise ValueError('Stepsize must be positive') + self.log_stepsize = Parameter(stepsize.log()) + self.n_steps = n_steps + + @overload + def __call__( + self, image: torch.Tensor, data: KData, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: ... + + @overload + def __call__( + self, image: torch.Tensor, data: torch.Tensor, acquisition_operator: LinearOperator + ) -> torch.Tensor: ... + + def __call__( + self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: + """Apply the data consistency. + + Parameters + ---------- + image + Current image estimate. + data + k-space data. + acquisition_operator + Acquisition operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` + object, the Fourier operator is automatically created from the data. + + Returns + ------- + Updated image estimate. + """ + return super().__call__(image, data, acquisition_operator) + + def forward( + self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None + ) -> torch.Tensor: + """Apply the data consistency.""" + if acquisition_operator is None: + if isinstance(data, KData): + acquisition_operator = FourierOp.from_kdata(data) + else: + raise ValueError('Either a KData or an acquisition operator is required') + + data_ = data.data if isinstance(data, KData) else data + stepsize = self.log_stepsize.exp() + x = image + for _ in range(self.n_steps): + residual = acquisition_operator(x)[0] - data_ + x = x - stepsize * acquisition_operator.adjoint(residual)[0] + return x diff --git a/src/mrpro/nn/data_consistency/__init__.py b/src/mrpro/nn/data_consistency/__init__.py new file mode 100644 index 000000000..cea955c85 --- /dev/null +++ b/src/mrpro/nn/data_consistency/__init__.py @@ -0,0 +1,5 @@ +from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC +from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC +from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC + +__all__ = ["AnalyticCartesianDC", "ConjugateGradientDC", "GradientDescentDC"] \ No newline at end of file diff --git a/tests/nn/data_consistency/conftest.py b/tests/nn/data_consistency/conftest.py new file mode 100644 index 000000000..49fca4cf3 --- /dev/null +++ b/tests/nn/data_consistency/conftest.py @@ -0,0 +1,46 @@ +"""Test fixtures for data consistency tests.""" + +import pytest +from mrpro.data.KData import KData +from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.data.traj_calculators.KTrajectoryCartesian import KTrajectoryCartesian +from mrpro.operators.FourierOp import FourierOp +from mrpro.phantoms.EllipsePhantom import EllipsePhantom +from mrpro.utils import RandomGenerator + + +@pytest.fixture +def kdata(): + matrix = SpatialDimension(x=128, y=128, z=1) + kdata = EllipsePhantom().kdata(KTrajectoryCartesian.fullysampled(matrix), matrix) + return kdata + + +@pytest.fixture +def kdata_noisy(kdata: KData): + kdata_noisy = kdata.clone() + kdata_noisy.data += 0.1 * RandomGenerator(123).randn_like(kdata_noisy.data) + return kdata_noisy + + +@pytest.fixture +def kdata_us(kdata: KData): + return kdata[..., ::2, :].clone() + + +@pytest.fixture +def image_noisy(kdata_noisy: KData): + fourier_op = FourierOp.from_kdata(kdata_noisy) + return fourier_op.adjoint(kdata_noisy.data)[0] + + +@pytest.fixture +def image(kdata: KData): + fourier_op = FourierOp.from_kdata(kdata) + return fourier_op.adjoint(kdata.data)[0] + + +@pytest.fixture +def image_us(kdata_us: KData): + fourier_op = FourierOp.from_kdata(kdata_us) + return fourier_op.adjoint(kdata_us.data)[0] diff --git a/tests/nn/data_consistency/test_analyticcartesiandc.py b/tests/nn/data_consistency/test_analyticcartesiandc.py new file mode 100644 index 000000000..9020a8326 --- /dev/null +++ b/tests/nn/data_consistency/test_analyticcartesiandc.py @@ -0,0 +1,21 @@ +"""Tests for AnalyticCartesianDC module.""" + +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC + + +def test_analytic_cartesian_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: + image_noisy = image_noisy.clone().requires_grad_(True) + dc = AnalyticCartesianDC(initial_regularization_weight=1e-6) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_weight.grad is not None + assert not dc.log_weight.grad.isnan() + assert not image_noisy.grad.isnan().any() diff --git a/tests/nn/data_consistency/test_conjugategradientdc.py b/tests/nn/data_consistency/test_conjugategradientdc.py new file mode 100644 index 000000000..99de6f4f5 --- /dev/null +++ b/tests/nn/data_consistency/test_conjugategradientdc.py @@ -0,0 +1,21 @@ +"""Tests for ConjugateGradientDC module.""" + +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC + + +def test_conjugate_gradient_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: + image_noisy = image_noisy.clone().requires_grad_(True) + dc = ConjugateGradientDC(initial_regularization_weight=1.0) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_weight.grad is not None + assert not dc.log_weight.grad.isnan() + assert not image_noisy.grad.isnan().any() diff --git a/tests/nn/data_consistency/test_gradientdescentdc.py b/tests/nn/data_consistency/test_gradientdescentdc.py new file mode 100644 index 000000000..00a7648f9 --- /dev/null +++ b/tests/nn/data_consistency/test_gradientdescentdc.py @@ -0,0 +1,21 @@ +"""Tests for GradientDescentDC module.""" + +import torch +from mrpro.data.KData import KData +from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC + + +def test_gradient_descent_dc( + image_noisy: torch.Tensor, kdata_us: KData, image: torch.Tensor, image_us: torch.Tensor +) -> None: + image_noisy = image_noisy.clone().requires_grad_(True) + dc = GradientDescentDC(initial_stepsize=1.0) + result = dc(image_noisy, kdata_us) + loss = (result - image).abs().mean() + assert loss < (image_noisy - image).abs().mean() + assert loss < (image_us - image).abs().mean() + loss.backward() + assert image_noisy.grad is not None + assert dc.log_stepsize.grad is not None + assert not dc.log_stepsize.grad.isnan() + assert not image_noisy.grad.isnan().any()