Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
}
}
38 changes: 22 additions & 16 deletions src/mrpro/operators/FourierOp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Fourier Operator."""

import warnings
from collections.abc import Sequence
from functools import cached_property

Expand Down Expand Up @@ -66,20 +67,6 @@ def get_spatial_dims(spatial_dims: SpatialDimension, dims: Sequence[int]):
else:
self._nufft_dims.append(dim)

if self._fft_dims:
self._fast_fourier_op: FastFourierOp | None = FastFourierOp(
dim=tuple(self._fft_dims),
recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims),
encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims),
)
self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp(
encoding_matrix=encoding_matrix, traj=traj
)
else:
self._fast_fourier_op = None
self._cart_sampling_op = None

# Find dimensions which require NUFFT
if self._nufft_dims:
fft_dims_k210 = [
dim
Expand All @@ -88,10 +75,16 @@ def get_spatial_dims(spatial_dims: SpatialDimension, dims: Sequence[int]):
and not (traj.type_along_k210[dim] & TrajType.SINGLEVALUE)
]
if self._fft_dims != fft_dims_k210:
raise NotImplementedError(
warnings.warn(
'If both FFT and NUFFT dims are present, Cartesian FFT dims need to be aligned with the '
'k-space dimension, i.e. kx along k0, ky along k1 and kz along k2.',
'k-space dimension, i.e. kx along k0, ky along k1 and kz along k2. We are going to use NUFFT '
'for all dimensions. Creating your own FourierOp with the desired combination of FFT and NUFFT '
'dimensions might be more efficient. ',
stacklevel=2,
)
self._nufft_dims.extend(self._fft_dims)
self._nufft_dims.sort()
self._fft_dims = []

self._non_uniform_fast_fourier_op: NonUniformFastFourierOp | None = NonUniformFastFourierOp(
direction=tuple(self._nufft_dims), # type: ignore[arg-type]
Expand All @@ -102,6 +95,19 @@ def get_spatial_dims(spatial_dims: SpatialDimension, dims: Sequence[int]):
else:
self._non_uniform_fast_fourier_op = None

if self._fft_dims:
self._fast_fourier_op: FastFourierOp | None = FastFourierOp(
dim=tuple(self._fft_dims),
recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims),
encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims),
)
self._cart_sampling_op: CartesianSamplingOp | None = CartesianSamplingOp(
encoding_matrix=encoding_matrix, traj=traj
)
else:
self._fast_fourier_op = None
self._cart_sampling_op = None

self._trajectory_info = repr(traj)

@classmethod
Expand Down
8 changes: 6 additions & 2 deletions src/mrpro/operators/NonUniformFastFourierOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from types import EllipsisType
from typing import Literal

import numpy as np
import torch
from pytorch_finufft.functional import finufft_type1, finufft_type2
from typing_extensions import Self
Expand Down Expand Up @@ -60,6 +61,9 @@ def __init__(
# Convert to negative indexing
direction_dict = {'z': -3, 'y': -2, 'x': -1, -3: -3, -2: -2, -1: -1}
self._direction_zyx = tuple(direction_dict[d] for d in direction)
# Directions are assumed to be in the order -3, -2, -1
direction_index = np.argsort(self._direction_zyx)
self._direction_zyx = tuple(self._direction_zyx[i] for i in direction_index)
if len(direction) != len(set(self._direction_zyx)):
raise ValueError(f'Directions must be unique. Normalized directions are {self._direction_zyx}')
if not self._direction_zyx:
Expand Down Expand Up @@ -101,15 +105,15 @@ def __init__(
else:
if (n_recon_matrix := len(recon_matrix)) != (n_nufft_dir := len(self._direction_zyx)):
raise ValueError(f'recon_matrix should have {n_nufft_dir} entries but has {n_recon_matrix}')
im_size = tuple(recon_matrix)
im_size = tuple(recon_matrix[i] for i in direction_index)
assert len(im_size) == 1 or len(im_size) == 2 or len(im_size) == 3 # mypy # noqa: S101

if isinstance(encoding_matrix, SpatialDimension):
k_size = tuple([int(encoding_matrix.zyx[d]) for d in self._direction_zyx])
else:
if (n_enc_matrix := len(encoding_matrix)) != (n_nufft_dir := len(self._direction_zyx)):
raise ValueError(f'encoding_matrix should have {n_nufft_dir} entries but has {n_enc_matrix}')
k_size = tuple(encoding_matrix)
k_size = tuple(encoding_matrix[i] for i in direction_index)

omega_list = [
k * 2 * torch.pi / ks
Expand Down
2 changes: 1 addition & 1 deletion tests/operators/test_fourier_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_fourier_op_not_supported_traj(
int(trajectory.ky.max() - trajectory.ky.min() + 1),
int(trajectory.kx.max() - trajectory.kx.min() + 1),
)
with pytest.raises(NotImplementedError, match='Cartesian FFT dims need to be aligned'):
with pytest.raises(UserWarning, match='Cartesian FFT dims need to be aligned'):
FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory)


Expand Down
16 changes: 16 additions & 0 deletions tests/operators/test_non_uniform_fast_fourier_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,23 @@ def test_non_uniform_fast_fourier_op_directions() -> None:
encoding_matrix=SpatialDimension[int](z=kdata_shape[-3], y=kdata_shape[-2], x=kdata_shape[-1]),
traj=traj,
)

nufft_op_xy = NonUniformFastFourierOp(
direction=('x', 'y'),
recon_matrix=SpatialDimension[int](z=img_shape[-3], y=img_shape[-2], x=img_shape[-1]),
encoding_matrix=SpatialDimension[int](z=kdata_shape[-3], y=kdata_shape[-2], x=kdata_shape[-1]),
traj=traj,
)

nufft_op_xy_matrix_sequence = NonUniformFastFourierOp(
direction=('x', 'y'),
recon_matrix=(img_shape[-1], img_shape[-2]),
encoding_matrix=(kdata_shape[-1], kdata_shape[-2]),
traj=traj,
)
torch.testing.assert_close(nufft_op_12(img)[0], nufft_op_yx(img)[0])
torch.testing.assert_close(nufft_op_12(img)[0], nufft_op_xy(img)[0])
torch.testing.assert_close(nufft_op_12(img)[0], nufft_op_xy_matrix_sequence(img)[0])


def test_non_uniform_fast_fourier_op_error_directions() -> None:
Expand Down
Loading