Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mrpro/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mrpro.nn.RMSNorm import RMSNorm
from mrpro.nn.Residual import Residual
from mrpro.nn.Sequential import Sequential
from mrpro.nn import data_consistency
from mrpro.nn.ndmodules import (
adaptiveAvgPoolND,
avgPoolND,
Expand Down Expand Up @@ -40,6 +41,7 @@
'batchNormND',
'convND',
'convTransposeND',
'data_consistency',
'instanceNormND',
'maxPoolND',
]
101 changes: 101 additions & 0 deletions src/mrpro/nn/data_consistency/AnalyticCartesianDC.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Analytic Cartesian data consistency."""

from typing import overload

import torch
from torch.nn import Module, Parameter

from mrpro.data.KData import KData
from mrpro.operators.FourierOp import FourierOp
from mrpro.operators.IdentityOp import IdentityOp


class AnalyticCartesianDC(Module):
r"""Analytic Cartesian data consistency.

Solves the following problem:
:math:`\min_x \|Ax - k\|_2^2 + \lambda \|x-p\|_2^2`
where :math:`A` is the acquisition operator and :math:`k` is the data, :math:`\lambda` is the regularization
parameter and :math:`p` is the regularization image/prior analytically. :math:`A^H A` has to be diagonal. This is a
special case for a Cartesian acquisition without coil sensitivity weighting. This can be used for either single-coil
data or to apply data consistency to each coil image [NOSENSE]_.

References
----------
.. [NOSENSE] Zimmermann, FF, and Kofler, Andreas. "NoSENSE: Learned unrolled cardiac MRI reconstruction without
explicit sensitivity maps." STACOM@MICCAI 2023. https://arxiv.org/abs/2309.15608

Parameters
----------
initial_regularization_weight
Initial regularization weight. The regularization weight is a trainable parameter.


"""

def __init__(self, initial_regularization_weight: torch.Tensor | float):
"""Initialize the data consistency.

Parameters
----------
initial_regularization_weight
Initial regularization weight. The regularization weight is a trainable parameter.
Must be a positive scalar.
"""
super().__init__()
weight = torch.as_tensor(initial_regularization_weight)
if weight.ndim != 0:
raise ValueError('Regularization weight must be a scalar')
if weight.item() <= 0:
raise ValueError('Regularization weight must be positive')
self.log_weight = Parameter(weight.log())

@overload
def __call__(self, image: torch.Tensor, data: KData, fourier_op: FourierOp | None = None) -> torch.Tensor: ...

@overload
def __call__(self, image: torch.Tensor, data: torch.Tensor, fourier_op: FourierOp) -> torch.Tensor: ...

def __call__(
self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None
) -> torch.Tensor:
"""Apply the data consistency.

Parameters
----------
image
Current image estimate, i.e. the regularized image.
data
k-space data.
fourier_op
Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object,
the Fourier operator is automatically created from the data.

Returns
-------
Updated image estimate.
"""
return super().__call__(image, data, fourier_op)

def forward(
self, image: torch.Tensor, data: KData | torch.Tensor, fourier_op: FourierOp | None = None
) -> torch.Tensor:
"""Apply the data consistency."""
if fourier_op is None:
if isinstance(data, KData):
fourier_op = FourierOp.from_kdata(data)
else:
raise ValueError('Either a KData or a FourierOp is required')

if not isinstance(fourier_op, FourierOp) or fourier_op._nufft_dims or fourier_op._fast_fourier_op is None:
raise ValueError('Only Cartesian acquisitions are supported')

data_ = data.data if isinstance(data, KData) else data
fft_op = fourier_op._fast_fourier_op
sampling_op = fourier_op._cart_sampling_op if fourier_op._cart_sampling_op is not None else IdentityOp()
(zero_filled,) = sampling_op.adjoint(data_)
(k_pred,) = fft_op(image)
regularization_weight = self.log_weight.exp()
(k,) = sampling_op.gram((zero_filled - k_pred) / (1 + regularization_weight))
(delta,) = fft_op.H(k)
return image + delta
117 changes: 117 additions & 0 deletions src/mrpro/nn/data_consistency/ConjugateGradientDC.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Conjugate gradient data consistency."""

from typing import overload

import torch
from torch.nn import Module, Parameter

from mrpro.data.CsmData import CsmData
from mrpro.data.KData import KData
from mrpro.operators.ConjugateGradientOp import ConjugateGradientOp
from mrpro.operators.FourierOp import FourierOp
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.operators.SensitivityOp import SensitivityOp


class ConjugateGradientDC(Module):
"""Conjugate gradient data consistency."""

def __init__(self, initial_regularization_weight: torch.Tensor | float):
"""Initialize the conjugate gradient data consistency.

Parameters
----------
initial_regularization_weight
Initial regularization weight. The regularization weight is a trainable parameter.
Must be a positive scalar.
"""
super().__init__()
weight = torch.as_tensor(initial_regularization_weight)
if weight.ndim != 0:
raise ValueError('Regularization weight must be a scalar')
if weight.item() <= 0:
raise ValueError('Regularization weight must be positive')
self.log_weight = Parameter(weight.log())

@overload
def __call__(
self,
image: torch.Tensor,
data: KData,
fourier_op: FourierOp | None = None,
csm: torch.Tensor | CsmData | None = None,
) -> torch.Tensor: ...

@overload
def __call__(
self,
image: torch.Tensor,
data: torch.Tensor,
fourier_op: FourierOp,
csm: torch.Tensor | CsmData | None = None,
) -> torch.Tensor: ...

def __call__(
self,
image: torch.Tensor,
data: KData | torch.Tensor,
fourier_op: LinearOperator | None = None,
csm: torch.Tensor | CsmData | None = None,
) -> torch.Tensor:
"""Apply the data consistency.

Parameters
----------
image
Current image estimate.
data
k-space data.
fourier_op
Fourier operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData` object,
the Fourier operator is automatically created from the data.
This operator can already include the coil sensitivity weighting, if gradients wrt the coil sensitivity maps
NOT required. Otherwise, they should be given as an additional argument.
csm
Coil sensitivity maps. If None, no coil sensitivity weighting is applied.

Returns
-------
Updated image estimate.
"""
return super().__call__(image, data, fourier_op, csm)

def forward(
self,
image: torch.Tensor,
data: torch.Tensor | KData,
fourier_op: FourierOp | None = None,
csm: torch.Tensor | CsmData | None = None,
) -> torch.Tensor:
"""Apply the data consistency."""
if fourier_op is None:
if isinstance(data, KData):
fourier_op = FourierOp.from_kdata(data)
else:
raise ValueError('Either a KData or a FourierOp is required')

data_ = data.data if isinstance(data, KData) else data

if csm is None:
csm = torch.tensor(())
elif isinstance(csm, CsmData):
csm = csm.data

def operator_factory(csm: torch.Tensor, weight: torch.Tensor, *_):
op = fourier_op.gram
if csm.numel():
csm_op = SensitivityOp(csm)
op = csm_op.H @ op @ csm_op
op = op + weight
return op

def rhs_factory(_csm: torch.Tensor, weight: torch.Tensor, zero_filled: torch.Tensor, image: torch.Tensor):
return (zero_filled + weight * image,)

cg_op = ConjugateGradientOp(operator_factory=operator_factory, rhs_factory=rhs_factory)
(result,) = cg_op(csm, self.log_weight.exp(), fourier_op.adjoint(data_)[0], image)
return result
99 changes: 99 additions & 0 deletions src/mrpro/nn/data_consistency/GradientDescentDC.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Gradient descent data consistency."""

from typing import overload

import torch
from torch.nn import Module, Parameter

from mrpro.data.KData import KData
from mrpro.operators.FourierOp import FourierOp
from mrpro.operators.LinearOperator import LinearOperator


class GradientDescentDC(Module):
r"""Gradient descent data consistency.

Performs gradient descent steps on
:math:`\|Ax - k\|_2^2` where :math:`A` is the acquisition operator and :math:`k` is the data.

Parameters
----------
initial_stepsize
Initial stepsize. The stepsize is a trainable parameter.
Must be a positive scalar.
n_steps
Number of gradient descent steps.

Returns
-------
The updated image.
"""

def __init__(self, initial_stepsize: float | torch.Tensor, n_steps: int = 1) -> None:
"""Initialize the gradient descent data consistency.

Parameters
----------
initial_stepsize
Initial stepsize. The stepsize is a trainable parameter.
Must be a positive scalar.
n_steps
Number of gradient descent steps.
"""
super().__init__()
stepsize = torch.as_tensor(initial_stepsize)
if stepsize.ndim != 0:
raise ValueError('Stepsize must be a scalar')
if stepsize.item() <= 0:
raise ValueError('Stepsize must be positive')
self.log_stepsize = Parameter(stepsize.log())
self.n_steps = n_steps

@overload
def __call__(
self, image: torch.Tensor, data: KData, acquisition_operator: LinearOperator | None = None
) -> torch.Tensor: ...

@overload
def __call__(
self, image: torch.Tensor, data: torch.Tensor, acquisition_operator: LinearOperator
) -> torch.Tensor: ...

def __call__(
self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None
) -> torch.Tensor:
"""Apply the data consistency.

Parameters
----------
image
Current image estimate.
data
k-space data.
acquisition_operator
Acquisition operator matching the k-space data. If None and data is provided as a `~mrpro.data.KData`
object, the Fourier operator is automatically created from the data.

Returns
-------
Updated image estimate.
"""
return super().__call__(image, data, acquisition_operator)

def forward(
self, image: torch.Tensor, data: KData | torch.Tensor, acquisition_operator: LinearOperator | None = None
) -> torch.Tensor:
"""Apply the data consistency."""
if acquisition_operator is None:
if isinstance(data, KData):
acquisition_operator = FourierOp.from_kdata(data)
else:
raise ValueError('Either a KData or an acquisition operator is required')

data_ = data.data if isinstance(data, KData) else data
stepsize = self.log_stepsize.exp()
x = image
for _ in range(self.n_steps):
residual = acquisition_operator(x)[0] - data_
x = x - stepsize * acquisition_operator.adjoint(residual)[0]
return x
5 changes: 5 additions & 0 deletions src/mrpro/nn/data_consistency/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from mrpro.nn.data_consistency.AnalyticCartesianDC import AnalyticCartesianDC
from mrpro.nn.data_consistency.GradientDescentDC import GradientDescentDC
from mrpro.nn.data_consistency.ConjugateGradientDC import ConjugateGradientDC

__all__ = ["AnalyticCartesianDC", "ConjugateGradientDC", "GradientDescentDC"]
46 changes: 46 additions & 0 deletions tests/nn/data_consistency/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Test fixtures for data consistency tests."""

import pytest
from mrpro.data.KData import KData
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.data.traj_calculators.KTrajectoryCartesian import KTrajectoryCartesian
from mrpro.operators.FourierOp import FourierOp
from mrpro.phantoms.EllipsePhantom import EllipsePhantom
from mrpro.utils import RandomGenerator


@pytest.fixture
def kdata():
matrix = SpatialDimension(x=128, y=128, z=1)
kdata = EllipsePhantom().kdata(KTrajectoryCartesian.fullysampled(matrix), matrix)
return kdata


@pytest.fixture
def kdata_noisy(kdata: KData):
kdata_noisy = kdata.clone()
kdata_noisy.data += 0.1 * RandomGenerator(123).randn_like(kdata_noisy.data)
return kdata_noisy


@pytest.fixture
def kdata_us(kdata: KData):
return kdata[..., ::2, :].clone()


@pytest.fixture
def image_noisy(kdata_noisy: KData):
fourier_op = FourierOp.from_kdata(kdata_noisy)
return fourier_op.adjoint(kdata_noisy.data)[0]


@pytest.fixture
def image(kdata: KData):
fourier_op = FourierOp.from_kdata(kdata)
return fourier_op.adjoint(kdata.data)[0]


@pytest.fixture
def image_us(kdata_us: KData):
fourier_op = FourierOp.from_kdata(kdata_us)
return fourier_op.adjoint(kdata_us.data)[0]
Loading
Loading