diff --git a/src/mrpro/algorithms/csm/inati.py b/src/mrpro/algorithms/csm/inati.py index f202fb629..935a459cf 100644 --- a/src/mrpro/algorithms/csm/inati.py +++ b/src/mrpro/algorithms/csm/inati.py @@ -4,67 +4,97 @@ from einops import einsum from mrpro.data.SpatialDimension import SpatialDimension -from mrpro.utils.sliding_window import sliding_window +from mrpro.utils.filters import uniform_filter def inati( coil_img: torch.Tensor, smoothing_width: SpatialDimension[int] | int, + n_iterations: int = 10, ) -> torch.Tensor: - """Calculate a coil sensitivity map (csm) using the Inati method [INA2013]_ [INA2014]_. + """Calculate a coil sensitivity map (csm) using the iterative Inati method [INA2014]_. - This is for a single set of coil images. The input should be a tensor with dimensions `(coils, z, y, x)`. The output - will have the same dimensions. Either apply this function individually to each set of coil images, or see - `~mrpro.data.CsmData.from_idata_inati` which performs this operation on a whole dataset. + This function computes CSMs using an iterative alternating minimization approach that estimates + the combined structural image and the coil sensitivity profiles simultaneously. Compared to + covariance-based methods, this approach is more memory-efficient and inherently enforces + spatial phase coherence. - .. [INA2013] Inati S, Hansen M, Kellman P (2013) A solution to the phase problem in adaptvie coil combination. - in Proceedings of the 21st Annual Meeting of ISMRM, Salt Lake City, USA, 2672. + The algorithm follows these steps: + + 1. **Initialize Combined Image**: + Create an initial coil combination using a global sum of the coil data. + + 2. **Iterative Refinement**: + For a specified number of iterations: + a. Update the CSMs by dividing the coil images by the combined image and smoothing the results. + b. Update the combined image using the newly estimated CSMs (SENSE combination). + c. Align the global phase at each step to avoid phase singularities and ensure spatial smoothness. + + 3. **Final Normalization**: + Normalize the sensitivity maps to have unit 2-norm across the coil dimension. - .. [INA2014] Inati S, Hansen M (2014) A Fast Optimal Method for Coil Sensitivity Estimation and Adaptive Coil - Combination for Complex images. in Proceedings of Joint Annual Meeting ISMRM-ESMRMB, Milan, Italy, 7115. + This function supports one or more sets of coil images with leading dimensions. The input + should be a tensor with dimensions `(..., coils, z, y, x)`. Prefer using + `~mrpro.data.CsmData` when sensitivity estimation should stay synchronized with MRpro data + containers and metadata. Parameters ---------- coil_img - images for each coil element + Images for each coil element, shape `(..., coils, z, y, x)`. smoothing_width - Size of the smoothing kernel + size of the smoothing kernel. + n_iterations + number of iterations to refine the maps. + + Returns + ------- + csm + Coil sensitivity map, shape `(..., coils, z, y, x)`. + + References + ---------- + .. [INA2013] Inati S, Hansen M, Kellman P (2013) A solution to the phase problem in adaptive coil combination. + in Proceedings of the 21st Annual Meeting of ISMRM, Salt Lake City, USA, 2672. + .. [INA2014] Inati S, Hansen M, Kellman P (2014) A Fast Optimal Method for Coil Sensitivity Estimation and + Adaptive Coil Combination for Complex Images. in Proceedings of Joint Annual Meeting ISMRM-ESMRMB, Milan, + Italy, 7115. """ - # After 10 power iterations we will have a very good estimate of the singular vector - n_power_iterations = 10 - eps = 1e-8 # for numerical stability + eps = 1e-12 + + if n_iterations < 1: + raise ValueError(f'n_iterations must be at least 1, got {n_iterations}') if isinstance(smoothing_width, int): smoothing_width = SpatialDimension( - z=smoothing_width if coil_img.shape[-3] > 1 else 1, y=smoothing_width, x=smoothing_width + z=smoothing_width if coil_img.shape[-3] > 1 else 1, + y=smoothing_width, + x=smoothing_width, ) if any(ks % 2 != 1 for ks in [smoothing_width.z, smoothing_width.y, smoothing_width.x]): raise ValueError('kernel_size must be odd') - ks_halved = [ks // 2 for ks in smoothing_width.zyx] - padded_coil_img = torch.nn.functional.pad( - coil_img, - (ks_halved[-1], ks_halved[-1], ks_halved[-2], ks_halved[-2], ks_halved[-3], ks_halved[-3]), - mode='replicate', - ) - # Get the voxels in an ROI defined by the smoothing_width around each voxel leading to shape - # (z y x coils window=prod(smoothing_width)) - coil_img_roi = sliding_window(padded_coil_img, smoothing_width.zyx, dim=(-3, -2, -1)).flatten(-3) - coil_img_cov = einsum( - coil_img_roi.conj(), - coil_img_roi, - '... coils1 window,... coils2 window->... coils1 coils2', - ) - - singular_vector = torch.sum(coil_img_roi, dim=-1) # z y x coils - singular_vector /= singular_vector.norm(dim=-1, keepdim=True) + eps - for _ in range(n_power_iterations): - singular_vector = einsum(coil_img_cov, singular_vector, '... coils1 coils2,... coils2->... coils1') - singular_vector /= singular_vector.norm(dim=-1, keepdim=True) + eps - - singular_value = einsum(coil_img_roi, singular_vector, '... coils window,... coils->... window') - phase = singular_value.sum(-1) - phase /= phase.abs() + eps - csm = einsum(singular_vector.conj(), phase, '... coils,...->coils ...') # coils z y x + # Initial guess for combined image phase + d_sum = torch.sum(coil_img, dim=(-3, -2, -1)) + d_sum /= d_sum.norm(dim=-1, keepdim=True) + eps + combined_img = einsum(d_sum.conj(), coil_img, '... c, ... c z y x -> ... z y x') + + for _ in range(n_iterations): + # Update CSM + csm = coil_img * combined_img.conj().unsqueeze(-4) + csm = uniform_filter(csm, width=smoothing_width.zyx, dim=(-3, -2, -1)) + csm /= csm.norm(dim=-4, keepdim=True) + eps + + # Update Combined Image + combined_img = einsum(coil_img, csm.conj(), '... c z y x, ... c z y x -> ... z y x') + + # Compute global phase reference and align + d_sum = (csm * combined_img.unsqueeze(-4)).sum(dim=(-3, -2, -1)) + d_sum /= d_sum.norm(dim=-1, keepdim=True) + eps + + phase = einsum(d_sum.conj(), csm, '... c, ... c z y x -> ... z y x').angle() + combined_img *= torch.exp(1j * phase) + csm *= torch.exp(-1j * phase).unsqueeze(-4) + return csm diff --git a/src/mrpro/algorithms/csm/walsh.py b/src/mrpro/algorithms/csm/walsh.py index 387257334..685358f19 100644 --- a/src/mrpro/algorithms/csm/walsh.py +++ b/src/mrpro/algorithms/csm/walsh.py @@ -1,4 +1,4 @@ -"""(Iterative) Walsh method for coil sensitivity map calculation.""" +"""Walsh method for coil sensitivity map calculation.""" import torch @@ -6,8 +6,12 @@ from mrpro.utils.filters import uniform_filter -def walsh(coil_images: torch.Tensor, smoothing_width: SpatialDimension[int] | int) -> torch.Tensor: - """Calculate a coil sensitivity map (csm) using an iterative version of the Walsh method. +def walsh( + coil_images: torch.Tensor, + smoothing_width: SpatialDimension[int] | int, + align_phase: bool = True, +) -> torch.Tensor: + """Calculate a coil sensitivity map (csm) using Walsh's method [WAL2000]_. This function computes CSMs from a set of complex coil images assuming spatially slowly changing sensitivity maps using Walsh's method [WAL2000]_. @@ -28,44 +32,64 @@ def walsh(coil_images: torch.Tensor, smoothing_width: SpatialDimension[int] | in 4. **Normalize Sensitivity Maps**: Normalize the resulting eigenvectors to produce the final CSMs. - This function works on a single set of coil images. The input should be a tensor with dimensions - `(coils, z, y, x)`. The output will have the same dimensions. Either apply this function individually to each set of - coil images, or see `~mrpro.data.CsmData.from_idata_walsh` which performs this operation on a whole dataset - [WAL2000]_. + 5. **Phase Alignment (Optional)**: + If `align_phase` is True, aligns the eigenvectors' global phase to a reference derived from the + coil data [INA2013]_. This prevents phase singularities that otherwise cause destructive + interference when spatially interpolating or downsampling the maps. - This implementation is inspired by `ismrmrd-python-tools `_. + This function supports one or more sets of coil images with leading dimensions. The input + should be a tensor with dimensions `(..., coils, z, y, x)`. Prefer using + `~mrpro.data.CsmData` when sensitivity estimation should stay synchronized with MRpro data + containers and metadata. Parameters ---------- coil_images - images for each coil element + Images for each coil element, shape `(..., coils, z, y, x)`. smoothing_width - width of the smoothing filter + width of the smoothing filter. + align_phase + if True, resolve the phase ambiguity of eigenvectors relative to the data [INA2013]_. + + Returns + ------- + csm + Coil sensitivity map, shape `(..., coils, z, y, x)`. References ---------- .. [WAL2000] Walsh DO, Gmitro AF, Marcellin MW (2000) Adaptive reconstruction of phased array MR imagery. MRM 43 + .. [INA2013] Inati S, Hansen M, Kellman P (2013) A solution to the phase problem in adaptive coil combination. + in Proceedings of the 21st Annual Meeting of ISMRM, Salt Lake City, USA, 2672. """ - # After 10 power iterations we will have a very good estimate of the singular vector n_power_iterations = 10 + eps = 1e-12 if isinstance(smoothing_width, int): - smoothing_width = SpatialDimension(smoothing_width, smoothing_width, smoothing_width) - # Compute the pointwise covariance between coils - coil_covariance = torch.einsum('azyx,bzyx->abzyx', coil_images, coil_images.conj()) + smoothing_width = SpatialDimension( + z=smoothing_width if coil_images.shape[-3] > 1 else 1, y=smoothing_width, x=smoothing_width + ) + + # Pointwise covariance + coil_covariance = torch.einsum('... a z y x, ... b z y x -> ... a b z y x', coil_images, coil_images.conj()) - # Smooth the covariance along y-x for 2D and z-y-x for 3D data + # Smooth covariance coil_covariance = uniform_filter(coil_covariance, width=smoothing_width.zyx, dim=(-3, -2, -1)) - # At each point in the image, find the dominant eigenvector - # of the signal covariance matrix using the power method - v = coil_covariance.sum(dim=0) + # Power iterations for dominant eigenvector + v = coil_covariance.sum(dim=-4) for _ in range(n_power_iterations): - v /= v.norm(dim=0) - v = torch.einsum('abzyx,bzyx->azyx', coil_covariance, v) - csm = v / v.norm(dim=0) + v = v / (v.norm(dim=-4, keepdim=True) + eps) + v = torch.einsum('... a b z y x, ... b z y x -> ... a z y x', coil_covariance, v) + + csm = v / (v.norm(dim=-4, keepdim=True) + eps) + + if align_phase: + # Resolve global phase ambiguity using a low-res data projection + d_sum = torch.sum(coil_images, dim=(-3, -2, -1), keepdim=True) + d_sum /= d_sum.norm(dim=-4, keepdim=True) + eps + phase_map = torch.einsum('... c z y x, ... c z y x -> ... z y x', d_sum.conj(), csm).angle() + csm = csm * torch.exp(-1j * phase_map).unsqueeze(-4) - # Make sure there are no inf or nan-values due to very small values in the covariance matrix - # nan_to_num does not work for complexfloat, boolean indexing not with vmap. csm = torch.where(torch.isfinite(csm), csm, 0.0) return csm diff --git a/src/mrpro/data/CsmData.py b/src/mrpro/data/CsmData.py index 8dbdf3fe3..04af92921 100644 --- a/src/mrpro/data/CsmData.py +++ b/src/mrpro/data/CsmData.py @@ -55,6 +55,7 @@ def from_kdata_walsh( kdata: KData, noise: KNoise | None = None, smoothing_width: int | SpatialDimension[int] = 5, + align_phase: bool = True, chunk_size_otherdim: int | None = None, downsampled_size: int | SpatialDimension[int] | None = None, ) -> Self: @@ -69,7 +70,10 @@ def from_kdata_walsh( noise, optional Noise measurement for prewhitening. smoothing_width - width of smoothing filter + Width of smoothing filter. + align_phase + If `True`, resolve the global phase ambiguity using the phase alignment + described in [INA2013]_. chunk_size_otherdim How many elements of the other dimensions should be processed at once. Default is `None`, which means that all elements are processed at once. @@ -87,6 +91,7 @@ def from_kdata_walsh( return cls.from_idata_walsh( DirectReconstruction(kdata, noise=noise, csm=None)(kdata), smoothing_width, + align_phase, chunk_size_otherdim, downsampled_size, ) @@ -96,6 +101,7 @@ def from_idata_walsh( cls, idata: IData, smoothing_width: int | SpatialDimension[int] = 5, + align_phase: bool = True, chunk_size_otherdim: int | None = None, downsampled_size: int | SpatialDimension[int] | None = None, ) -> Self: @@ -108,7 +114,10 @@ def from_idata_walsh( idata IData object containing the images for each coil element. smoothing_width - width of smoothing filter. + Width of smoothing filter. + align_phase + If `True`, resolve the global phase ambiguity using the phase alignment + described in [INA2013]_. chunk_size_otherdim: How many elements of the other dimensions should be processed at once. Default is `None`, which means that all elements are processed at once. @@ -126,7 +135,7 @@ def from_idata_walsh( csm_fun = torch.vmap( lambda img: apply_lowres( - lambda x: walsh(x, smoothing_width), + lambda x: walsh(x, smoothing_width, align_phase=align_phase), size=get_downsampled_size(idata.data.shape, downsampled_size), dim=(-3, -2, -1), )(img), @@ -144,10 +153,11 @@ def from_kdata_inati( kdata: KData, noise: KNoise | None = None, smoothing_width: int | SpatialDimension[int] = 5, + n_iterations: int = 5, chunk_size_otherdim: int | None = None, downsampled_size: int | SpatialDimension[int] | None = None, ) -> Self: - """Create csm object from k-space data using Inati method. + """Create csm object from k-space data using the iterative Inati method. See also `~mrpro.algorithms.csm.inati`. @@ -158,7 +168,10 @@ def from_kdata_inati( noise, optional Noise measurement for prewhitening. smoothing_width - width of smoothing filter + Width of smoothing filter. + n_iterations + Number of Inati iterations used to refine the combined image and the + coil sensitivity maps. chunk_size_otherdim How many elements of the other dimensions should be processed at once. Default is `None`, which means that all elements are processed at once. @@ -176,6 +189,7 @@ def from_kdata_inati( return cls.from_idata_inati( DirectReconstruction(kdata, noise=noise, csm=None)(kdata), smoothing_width, + n_iterations, chunk_size_otherdim, downsampled_size, ) @@ -185,10 +199,11 @@ def from_idata_inati( cls, idata: IData, smoothing_width: int | SpatialDimension[int] = 5, + n_iterations: int = 5, chunk_size_otherdim: int | None = None, downsampled_size: int | SpatialDimension[int] | None = None, ) -> Self: - """Create csm object from image data using Inati method. + """Create csm object from image data using the iterative Inati method. See also `~mrpro.algorithms.csm.inati`. @@ -198,6 +213,9 @@ def from_idata_inati( IData object containing the images for each coil element. smoothing_width Size of the smoothing kernel. + n_iterations + Number of Inati iterations used to refine the combined image and the + coil sensitivity maps. chunk_size_otherdim: How many elements of the other dimensions should be processed at once. Default is `None`, which means that all elements are processed at once. @@ -214,7 +232,7 @@ def from_idata_inati( csm_fun = torch.vmap( lambda img: apply_lowres( - lambda x: inati(x, smoothing_width), + lambda x: inati(x, smoothing_width, n_iterations=n_iterations), size=get_downsampled_size(idata.data.shape, downsampled_size), dim=(-3, -2, -1), )(img), diff --git a/tests/algorithms/csm/test_inati.py b/tests/algorithms/csm/test_inati.py index beaa2fa5d..4acd7eae3 100644 --- a/tests/algorithms/csm/test_inati.py +++ b/tests/algorithms/csm/test_inati.py @@ -1,5 +1,6 @@ """Tests the iterative Walsh algorithm.""" +import pytest import torch from mrpro.algorithms.csm import inati from mrpro.data import SpatialDimension @@ -18,3 +19,11 @@ def test_inati(ellipse_phantom, random_kheader): # Phase is only relative in csm calculation, therefore only the abs values are compared. assert relative_image_difference(torch.abs(csm), torch.abs(csm_ref[0, ...])) <= 0.01 + + +def test_inati_requires_positive_iterations() -> None: + """Test that Inati validates the number of iterations.""" + coil_img = torch.ones(2, 1, 4, 4, dtype=torch.complex64) + + with pytest.raises(ValueError, match='n_iterations must be at least 1'): + inati(coil_img, smoothing_width=1, n_iterations=0) diff --git a/tests/data/test_csm_data.py b/tests/data/test_csm_data.py index c8a6acdb1..6ad5a93c4 100644 --- a/tests/data/test_csm_data.py +++ b/tests/data/test_csm_data.py @@ -87,3 +87,17 @@ def test_CsmData_kdata_inati(ismrmrd_cart_single_rep) -> None: csm_from_kdata = CsmData.from_kdata_inati(kdata) csm_from_idata = CsmData.from_idata_inati(idata) torch.testing.assert_close(csm_from_kdata.data, csm_from_idata.data, rtol=1e-5, atol=1e-5) + + +def test_CsmData_walsh_align_phase_matches_explicit_call( + ellipse_phantom: EllipsePhantomTestData, random_kheader: KHeader +) -> None: + """CsmData Walsh uses phase alignment by default.""" + idata, _ = multi_coil_image(n_coils=4, ph_ellipse=ellipse_phantom, random_kheader=random_kheader) + + csm_default = CsmData.from_idata_walsh(idata) + csm_no_phase = CsmData.from_idata_walsh(idata, align_phase=False) + csm_with_phase = CsmData.from_idata_walsh(idata, align_phase=True) + + torch.testing.assert_close(csm_default.data, csm_with_phase.data) + assert not torch.allclose(csm_default.data, csm_no_phase.data)