Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 70 additions & 40 deletions src/mrpro/algorithms/csm/inati.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,67 +4,97 @@
from einops import einsum

from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.utils.sliding_window import sliding_window
from mrpro.utils.filters import uniform_filter


def inati(
coil_img: torch.Tensor,
smoothing_width: SpatialDimension[int] | int,
n_iterations: int = 10,
) -> torch.Tensor:
"""Calculate a coil sensitivity map (csm) using the Inati method [INA2013]_ [INA2014]_.
"""Calculate a coil sensitivity map (csm) using the iterative Inati method [INA2014]_.

This is for a single set of coil images. The input should be a tensor with dimensions `(coils, z, y, x)`. The output
will have the same dimensions. Either apply this function individually to each set of coil images, or see
`~mrpro.data.CsmData.from_idata_inati` which performs this operation on a whole dataset.
This function computes CSMs using an iterative alternating minimization approach that estimates
the combined structural image and the coil sensitivity profiles simultaneously. Compared to
covariance-based methods, this approach is more memory-efficient and inherently enforces
spatial phase coherence.

.. [INA2013] Inati S, Hansen M, Kellman P (2013) A solution to the phase problem in adaptvie coil combination.
in Proceedings of the 21st Annual Meeting of ISMRM, Salt Lake City, USA, 2672.
The algorithm follows these steps:

1. **Initialize Combined Image**:
Create an initial coil combination using a global sum of the coil data.

2. **Iterative Refinement**:
For a specified number of iterations:
a. Update the CSMs by dividing the coil images by the combined image and smoothing the results.
b. Update the combined image using the newly estimated CSMs (SENSE combination).
c. Align the global phase at each step to avoid phase singularities and ensure spatial smoothness.

3. **Final Normalization**:
Normalize the sensitivity maps to have unit 2-norm across the coil dimension.

.. [INA2014] Inati S, Hansen M (2014) A Fast Optimal Method for Coil Sensitivity Estimation and Adaptive Coil
Combination for Complex images. in Proceedings of Joint Annual Meeting ISMRM-ESMRMB, Milan, Italy, 7115.
This function supports one or more sets of coil images with leading dimensions. The input
should be a tensor with dimensions `(..., coils, z, y, x)`. Prefer using
`~mrpro.data.CsmData` when sensitivity estimation should stay synchronized with MRpro data
containers and metadata.

Parameters
----------
coil_img
images for each coil element
Images for each coil element, shape `(..., coils, z, y, x)`.
smoothing_width
Size of the smoothing kernel
size of the smoothing kernel.
n_iterations
number of iterations to refine the maps.

Returns
-------
csm
Coil sensitivity map, shape `(..., coils, z, y, x)`.

References
----------
.. [INA2013] Inati S, Hansen M, Kellman P (2013) A solution to the phase problem in adaptive coil combination.
in Proceedings of the 21st Annual Meeting of ISMRM, Salt Lake City, USA, 2672.
.. [INA2014] Inati S, Hansen M, Kellman P (2014) A Fast Optimal Method for Coil Sensitivity Estimation and
Adaptive Coil Combination for Complex Images. in Proceedings of Joint Annual Meeting ISMRM-ESMRMB, Milan,
Italy, 7115.
"""
# After 10 power iterations we will have a very good estimate of the singular vector
n_power_iterations = 10
eps = 1e-8 # for numerical stability
eps = 1e-12

if n_iterations < 1:
raise ValueError(f'n_iterations must be at least 1, got {n_iterations}')

if isinstance(smoothing_width, int):
smoothing_width = SpatialDimension(
z=smoothing_width if coil_img.shape[-3] > 1 else 1, y=smoothing_width, x=smoothing_width
z=smoothing_width if coil_img.shape[-3] > 1 else 1,
y=smoothing_width,
x=smoothing_width,
)

if any(ks % 2 != 1 for ks in [smoothing_width.z, smoothing_width.y, smoothing_width.x]):
raise ValueError('kernel_size must be odd')

ks_halved = [ks // 2 for ks in smoothing_width.zyx]
padded_coil_img = torch.nn.functional.pad(
coil_img,
(ks_halved[-1], ks_halved[-1], ks_halved[-2], ks_halved[-2], ks_halved[-3], ks_halved[-3]),
mode='replicate',
)
# Get the voxels in an ROI defined by the smoothing_width around each voxel leading to shape
# (z y x coils window=prod(smoothing_width))
coil_img_roi = sliding_window(padded_coil_img, smoothing_width.zyx, dim=(-3, -2, -1)).flatten(-3)
coil_img_cov = einsum(
coil_img_roi.conj(),
coil_img_roi,
'... coils1 window,... coils2 window->... coils1 coils2',
)

singular_vector = torch.sum(coil_img_roi, dim=-1) # z y x coils
singular_vector /= singular_vector.norm(dim=-1, keepdim=True) + eps
for _ in range(n_power_iterations):
singular_vector = einsum(coil_img_cov, singular_vector, '... coils1 coils2,... coils2->... coils1')
singular_vector /= singular_vector.norm(dim=-1, keepdim=True) + eps

singular_value = einsum(coil_img_roi, singular_vector, '... coils window,... coils->... window')
phase = singular_value.sum(-1)
phase /= phase.abs() + eps
csm = einsum(singular_vector.conj(), phase, '... coils,...->coils ...') # coils z y x
# Initial guess for combined image phase
d_sum = torch.sum(coil_img, dim=(-3, -2, -1))
d_sum /= d_sum.norm(dim=-1, keepdim=True) + eps
combined_img = einsum(d_sum.conj(), coil_img, '... c, ... c z y x -> ... z y x')

for _ in range(n_iterations):
# Update CSM
csm = coil_img * combined_img.conj().unsqueeze(-4)
csm = uniform_filter(csm, width=smoothing_width.zyx, dim=(-3, -2, -1))
csm /= csm.norm(dim=-4, keepdim=True) + eps

# Update Combined Image
combined_img = einsum(coil_img, csm.conj(), '... c z y x, ... c z y x -> ... z y x')

# Compute global phase reference and align
d_sum = (csm * combined_img.unsqueeze(-4)).sum(dim=(-3, -2, -1))
d_sum /= d_sum.norm(dim=-1, keepdim=True) + eps

phase = einsum(d_sum.conj(), csm, '... c, ... c z y x -> ... z y x').angle()
combined_img *= torch.exp(1j * phase)
csm *= torch.exp(-1j * phase).unsqueeze(-4)

return csm
70 changes: 47 additions & 23 deletions src/mrpro/algorithms/csm/walsh.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""(Iterative) Walsh method for coil sensitivity map calculation."""
"""Walsh method for coil sensitivity map calculation."""

import torch

from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.utils.filters import uniform_filter


def walsh(coil_images: torch.Tensor, smoothing_width: SpatialDimension[int] | int) -> torch.Tensor:
"""Calculate a coil sensitivity map (csm) using an iterative version of the Walsh method.
def walsh(
coil_images: torch.Tensor,
smoothing_width: SpatialDimension[int] | int,
align_phase: bool = True,
) -> torch.Tensor:
"""Calculate a coil sensitivity map (csm) using Walsh's method [WAL2000]_.

This function computes CSMs from a set of complex coil images assuming spatially
slowly changing sensitivity maps using Walsh's method [WAL2000]_.
Expand All @@ -28,44 +32,64 @@ def walsh(coil_images: torch.Tensor, smoothing_width: SpatialDimension[int] | in
4. **Normalize Sensitivity Maps**:
Normalize the resulting eigenvectors to produce the final CSMs.

This function works on a single set of coil images. The input should be a tensor with dimensions
`(coils, z, y, x)`. The output will have the same dimensions. Either apply this function individually to each set of
coil images, or see `~mrpro.data.CsmData.from_idata_walsh` which performs this operation on a whole dataset
[WAL2000]_.
5. **Phase Alignment (Optional)**:
If `align_phase` is True, aligns the eigenvectors' global phase to a reference derived from the
coil data [INA2013]_. This prevents phase singularities that otherwise cause destructive
interference when spatially interpolating or downsampling the maps.

This implementation is inspired by `ismrmrd-python-tools <https://github.com/ismrmrd/ismrmrd-python-tools>`_.
This function supports one or more sets of coil images with leading dimensions. The input
should be a tensor with dimensions `(..., coils, z, y, x)`. Prefer using
`~mrpro.data.CsmData` when sensitivity estimation should stay synchronized with MRpro data
containers and metadata.

Parameters
----------
coil_images
images for each coil element
Images for each coil element, shape `(..., coils, z, y, x)`.
smoothing_width
width of the smoothing filter
width of the smoothing filter.
align_phase
if True, resolve the phase ambiguity of eigenvectors relative to the data [INA2013]_.

Returns
-------
csm
Coil sensitivity map, shape `(..., coils, z, y, x)`.

References
----------
.. [WAL2000] Walsh DO, Gmitro AF, Marcellin MW (2000) Adaptive reconstruction of phased array MR imagery. MRM 43
.. [INA2013] Inati S, Hansen M, Kellman P (2013) A solution to the phase problem in adaptive coil combination.
in Proceedings of the 21st Annual Meeting of ISMRM, Salt Lake City, USA, 2672.
"""
# After 10 power iterations we will have a very good estimate of the singular vector
n_power_iterations = 10
eps = 1e-12

if isinstance(smoothing_width, int):
smoothing_width = SpatialDimension(smoothing_width, smoothing_width, smoothing_width)
# Compute the pointwise covariance between coils
coil_covariance = torch.einsum('azyx,bzyx->abzyx', coil_images, coil_images.conj())
smoothing_width = SpatialDimension(
z=smoothing_width if coil_images.shape[-3] > 1 else 1, y=smoothing_width, x=smoothing_width
)

# Pointwise covariance
coil_covariance = torch.einsum('... a z y x, ... b z y x -> ... a b z y x', coil_images, coil_images.conj())

# Smooth the covariance along y-x for 2D and z-y-x for 3D data
# Smooth covariance
coil_covariance = uniform_filter(coil_covariance, width=smoothing_width.zyx, dim=(-3, -2, -1))

# At each point in the image, find the dominant eigenvector
# of the signal covariance matrix using the power method
v = coil_covariance.sum(dim=0)
# Power iterations for dominant eigenvector
v = coil_covariance.sum(dim=-4)
for _ in range(n_power_iterations):
v /= v.norm(dim=0)
v = torch.einsum('abzyx,bzyx->azyx', coil_covariance, v)
csm = v / v.norm(dim=0)
v = v / (v.norm(dim=-4, keepdim=True) + eps)
v = torch.einsum('... a b z y x, ... b z y x -> ... a z y x', coil_covariance, v)

csm = v / (v.norm(dim=-4, keepdim=True) + eps)

if align_phase:
# Resolve global phase ambiguity using a low-res data projection
d_sum = torch.sum(coil_images, dim=(-3, -2, -1), keepdim=True)
d_sum /= d_sum.norm(dim=-4, keepdim=True) + eps
phase_map = torch.einsum('... c z y x, ... c z y x -> ... z y x', d_sum.conj(), csm).angle()
csm = csm * torch.exp(-1j * phase_map).unsqueeze(-4)

# Make sure there are no inf or nan-values due to very small values in the covariance matrix
# nan_to_num does not work for complexfloat, boolean indexing not with vmap.
csm = torch.where(torch.isfinite(csm), csm, 0.0)
return csm
32 changes: 25 additions & 7 deletions src/mrpro/data/CsmData.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def from_kdata_walsh(
kdata: KData,
noise: KNoise | None = None,
smoothing_width: int | SpatialDimension[int] = 5,
align_phase: bool = True,
chunk_size_otherdim: int | None = None,
downsampled_size: int | SpatialDimension[int] | None = None,
) -> Self:
Expand All @@ -69,7 +70,10 @@ def from_kdata_walsh(
noise, optional
Noise measurement for prewhitening.
smoothing_width
width of smoothing filter
Width of smoothing filter.
align_phase
If `True`, resolve the global phase ambiguity using the phase alignment
described in [INA2013]_.
chunk_size_otherdim
How many elements of the other dimensions should be processed at once.
Default is `None`, which means that all elements are processed at once.
Expand All @@ -87,6 +91,7 @@ def from_kdata_walsh(
return cls.from_idata_walsh(
DirectReconstruction(kdata, noise=noise, csm=None)(kdata),
smoothing_width,
align_phase,
chunk_size_otherdim,
downsampled_size,
)
Expand All @@ -96,6 +101,7 @@ def from_idata_walsh(
cls,
idata: IData,
smoothing_width: int | SpatialDimension[int] = 5,
align_phase: bool = True,
chunk_size_otherdim: int | None = None,
downsampled_size: int | SpatialDimension[int] | None = None,
) -> Self:
Expand All @@ -108,7 +114,10 @@ def from_idata_walsh(
idata
IData object containing the images for each coil element.
smoothing_width
width of smoothing filter.
Width of smoothing filter.
align_phase
If `True`, resolve the global phase ambiguity using the phase alignment
described in [INA2013]_.
chunk_size_otherdim:
How many elements of the other dimensions should be processed at once.
Default is `None`, which means that all elements are processed at once.
Expand All @@ -126,7 +135,7 @@ def from_idata_walsh(

csm_fun = torch.vmap(
lambda img: apply_lowres(
lambda x: walsh(x, smoothing_width),
lambda x: walsh(x, smoothing_width, align_phase=align_phase),
size=get_downsampled_size(idata.data.shape, downsampled_size),
dim=(-3, -2, -1),
)(img),
Expand All @@ -144,10 +153,11 @@ def from_kdata_inati(
kdata: KData,
noise: KNoise | None = None,
smoothing_width: int | SpatialDimension[int] = 5,
n_iterations: int = 5,
chunk_size_otherdim: int | None = None,
downsampled_size: int | SpatialDimension[int] | None = None,
) -> Self:
"""Create csm object from k-space data using Inati method.
"""Create csm object from k-space data using the iterative Inati method.

See also `~mrpro.algorithms.csm.inati`.

Expand All @@ -158,7 +168,10 @@ def from_kdata_inati(
noise, optional
Noise measurement for prewhitening.
smoothing_width
width of smoothing filter
Width of smoothing filter.
n_iterations
Number of Inati iterations used to refine the combined image and the
coil sensitivity maps.
chunk_size_otherdim
How many elements of the other dimensions should be processed at once.
Default is `None`, which means that all elements are processed at once.
Expand All @@ -176,6 +189,7 @@ def from_kdata_inati(
return cls.from_idata_inati(
DirectReconstruction(kdata, noise=noise, csm=None)(kdata),
smoothing_width,
n_iterations,
chunk_size_otherdim,
downsampled_size,
)
Expand All @@ -185,10 +199,11 @@ def from_idata_inati(
cls,
idata: IData,
smoothing_width: int | SpatialDimension[int] = 5,
n_iterations: int = 5,
chunk_size_otherdim: int | None = None,
downsampled_size: int | SpatialDimension[int] | None = None,
) -> Self:
"""Create csm object from image data using Inati method.
"""Create csm object from image data using the iterative Inati method.

See also `~mrpro.algorithms.csm.inati`.

Expand All @@ -198,6 +213,9 @@ def from_idata_inati(
IData object containing the images for each coil element.
smoothing_width
Size of the smoothing kernel.
n_iterations
Number of Inati iterations used to refine the combined image and the
coil sensitivity maps.
chunk_size_otherdim:
How many elements of the other dimensions should be processed at once.
Default is `None`, which means that all elements are processed at once.
Expand All @@ -214,7 +232,7 @@ def from_idata_inati(

csm_fun = torch.vmap(
lambda img: apply_lowres(
lambda x: inati(x, smoothing_width),
lambda x: inati(x, smoothing_width, n_iterations=n_iterations),
size=get_downsampled_size(idata.data.shape, downsampled_size),
dim=(-3, -2, -1),
)(img),
Expand Down
9 changes: 9 additions & 0 deletions tests/algorithms/csm/test_inati.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests the iterative Walsh algorithm."""

import pytest
import torch
from mrpro.algorithms.csm import inati
from mrpro.data import SpatialDimension
Expand All @@ -18,3 +19,11 @@ def test_inati(ellipse_phantom, random_kheader):

# Phase is only relative in csm calculation, therefore only the abs values are compared.
assert relative_image_difference(torch.abs(csm), torch.abs(csm_ref[0, ...])) <= 0.01


def test_inati_requires_positive_iterations() -> None:
"""Test that Inati validates the number of iterations."""
coil_img = torch.ones(2, 1, 4, 4, dtype=torch.complex64)

with pytest.raises(ValueError, match='n_iterations must be at least 1'):
inati(coil_img, smoothing_width=1, n_iterations=0)
Loading
Loading