Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions src/mrpro/operators/FiniteDifferenceOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,29 @@


class FiniteDifferenceOp(LinearOperator):
"""Finite Difference Operator."""
r"""Finite difference operator.

This pointwise operator computes finite differences of a discrete :math:`d`-dimensional tensor ``x``.
Differences are computed along the axes listed in ``dim``
(e.g. ``dim=(-2, -1)`` for the last two axes)
by means of a separable convolution with appropriate filters
(supported modes are ``forward``, ``backward``, and ``central``).
The output is a :math:`(d+1)`-dimensional tensor ``y``
(``y.shape[0] == len(dim)``)
where each ``y[i]`` is the finite difference tensor along the selected axis ``dim[i]``.

For example, the forward finite difference ``nabla(x)`` along a chosen axis ``dim[i]`` can be written as

.. code-block:: python

y[i, *k] = nabla(x)[i, *k] = x[*(k + e_i)] - x[*k]

for every coordinate ``k = (k_1, ..., k_d)`` in the grid.
Here ``e_i = (e_i_1, ..., e_i_d)`` is the unit vector in direction ``dim[i]``,
i.e. ``e_i[j] = 1`` if ``j == dim[i]`` else ``0``.

Boundary handling (e.g. when coordinate ``k + e_i`` is outside the grid) is controlled by ``pad_mode``.
"""

@staticmethod
def finite_difference_kernel(mode: Literal['central', 'forward', 'backward']) -> torch.Tensor:
Expand All @@ -28,30 +50,30 @@ def finite_difference_kernel(mode: Literal['central', 'forward', 'backward']) ->
Raises
------
`ValueError`
If mode is not central, forward, backward or doublecentral
If mode is not forward, backward or central
"""
if mode == 'central':
kernel = torch.tensor((-1, 0, 1)) / 2
elif mode == 'forward':
if mode == 'forward':
kernel = torch.tensor((0, -1, 1))
elif mode == 'backward':
kernel = torch.tensor((-1, 1, 0))
elif mode == 'central':
kernel = torch.tensor((-1, 0, 1)) / 2
else:
raise ValueError(f'mode should be one of (central, forward, backward), not {mode}')
return kernel

def __init__(
self,
dim: Sequence[int],
mode: Literal['central', 'forward', 'backward'] = 'central',
mode: Literal['central', 'forward', 'backward'] = 'forward',
pad_mode: Literal['zeros', 'circular'] = 'zeros',
) -> None:
"""Finite difference operator.

Parameters
----------
dim
Dimension along which finite differences are calculated.
Dimensions along which finite differences are calculated.
mode
Type of finite difference operator
pad_mode
Expand Down
121 changes: 121 additions & 0 deletions src/mrpro/operators/SymmetrizedGradientOp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Class for Symmetrized Gradient Operator."""

from collections.abc import Sequence
from typing import Literal

import torch

from mrpro.operators.FiniteDifferenceOp import FiniteDifferenceOp
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.operators.RearrangeOp import RearrangeOp


class SymmetrizedGradientOp(LinearOperator):
r"""Discrete symmetrized gradient operator using finite differences.

Based on finite differences along the axes listed in ``dim``
(e.g. ``dim=(-2, -1)`` for the last two axes, see :class:`mrpro.operators.FiniteDifferenceOp`),
this pointwise operator computes the symmetrized gradient of a discrete vector field,
i.e. a :math:`(d+1)`-dimensional tensor ``v`` (``v.shape[0] == len(dim)``),
where each ``v[j]`` is a :math:`d`-dimensional vector component.
The output is a :math:`(d+2)`-dimensional tensor ``w``
(``w.shape[0] == w.shape[1] == len(dim)``) where each ``w[i, j]``
contains the symmetric part of the discrete gradient of ``v``,
computed along the axes listed in ``dim``.
Note that ``dim`` must not contain the :math:`0^{\text{th}}` axis.

The symmetrized gradient ``E(v)`` using the finite difference operator ``nabla`` can be written as

.. code-block:: python

w = E(v) = 0.5 * (nabla(v) + nabla(v).transpose(0, 1))

or more explicitly as

.. code-block:: python

w[i, j] = E(v)[i, j] = 0.5 * (nabla(v)[i, j] + nabla(v)[j, i])

for every ``i``, ``j`` in ``[0, ..., len(dim) - 1]``.

Finite difference modes and boundary handling follow :class:`mrpro.operators.FiniteDifferenceOp`.

A common use case of the symmetrized gradient is Total Generalized Variation (TGV) regularization,
with the 2D case (i.e. ``len(dim) == 2``) shown in [TGV]_.

References
----------
.. [TGV] Bredies, K. Recovering piecewise smooth multichannel images by minimization of convex
functionals with total generalized variation penalty. In: Bruhn, A., Pock, T., Tai, X.C. (eds)
Efficient Algorithms for Global Optimization Methods in Computer Vision. Lecture Notes in Computer Science,
vol. 8293, Springer, Berlin, Heidelberg, 2014, pp. 44-77. https://doi.org/10.1007/978-3-642-54774-4_3
"""

def __init__(
self,
dim: Sequence[int],
mode: Literal['central', 'forward', 'backward'] = 'backward',
pad_mode: Literal['zeros', 'circular'] = 'zeros',
) -> None:
r"""Symmetrized gradient operator.

Parameters
----------
dim
Dimensions along which finite differences are calculated.
It must not contain the :math:`0^{\text{th}}` axis.
mode
Type of finite difference operator
pad_mode
Padding to ensure output has the same size as the input
"""
super().__init__()
finite_difference_op = FiniteDifferenceOp(dim, mode=mode, pad_mode=pad_mode)
transpose_op = RearrangeOp('sym_grad_dim grad_dim ... -> grad_dim sym_grad_dim ...')
self._operator = 0.5 * (1 + transpose_op) @ finite_difference_op

def __call__(self, v: torch.Tensor) -> tuple[torch.Tensor,]:
r"""Apply forward of symmetrized gradient operator.

The length of the first axis of ``v`` (``v.shape[0]``) must match
the number of dimensions specified in ``dim`` during initialization.

Parameters
----------
v
:math:`(d+1)`-dimensional input tensor with
the first dimension indexing the :math:`d` vector components.

Returns
-------
A single-element tuple (``w``, ) containing a :math:`(d+2)`-dimensional tensor ``w``
which represents the symmetrized gradient of ``v``.
"""
return super().__call__(v)

def forward(self, v: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply forward of SymmetrizedGradientOp.

.. note::
Prefer calling the instance of the SymmetrizedGradientOp operator as ``operator(x)`` over
directly calling this method. See this PyTorch `discussion <https://discuss.pytorch.org/t/is-model-forward-x-the-same-as-model-call-x/33460/3>`_.
"""
return self._operator(v)

def adjoint(self, w: torch.Tensor) -> tuple[torch.Tensor,]:
r"""Apply adjoint of symmetrized gradient operator.

The lengths of the first two axes of ``w`` (``w.shape[0]`` and ``w.shape[1]``) must equal each other and
match the number of dimensions specified in ``dim`` during initialization.

Parameters
----------
w
:math:`(d+2)`-dimensional input tensor representing the symmetrized gradient.

Returns
-------
A single-element tuple (``v``, ) containing the :math:`(d+1)`-dimensional tensor ``v``
which represents the adjoint of the symmetrized gradient.
"""
return self._operator.adjoint(w)
4 changes: 3 additions & 1 deletion src/mrpro/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from mrpro.operators.SensitivityOp import SensitivityOp
from mrpro.operators.SignalModel import SignalModel
from mrpro.operators.SliceProjectionOp import SliceProjectionOp
from mrpro.operators.SymmetrizedGradientOp import SymmetrizedGradientOp
from mrpro.operators.WaveletOp import WaveletOp
from mrpro.operators.ZeroPadOp import ZeroPadOp
from mrpro.operators.ZeroOp import ZeroOp
Expand Down Expand Up @@ -70,9 +71,10 @@
"SensitivityOp",
"SignalModel",
"SliceProjectionOp",
"SymmetrizedGradientOp",
"WaveletOp",
"ZeroOp",
"ZeroPadOp",
"functionals",
"models"
]
]
129 changes: 129 additions & 0 deletions tests/operators/test_symmetrized_gradient_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Tests for symmetrized gradient operator."""

from collections.abc import Sequence
from typing import Literal

import pytest
import torch
from einops import repeat
from mrpro.operators import SymmetrizedGradientOp
from mrpro.utils import RandomGenerator

from tests import (
dotproduct_adjointness_test,
forward_mode_autodiff_of_linear_operator_test,
gradient_of_linear_operator_test,
)


def create_symmetrized_gradient_op_and_range_domain(
dim: Sequence[int],
mode: Literal['central', 'forward', 'backward'],
pad_mode: Literal['zeros', 'circular'],
) -> tuple[SymmetrizedGradientOp, torch.Tensor, torch.Tensor]:
"""Create a symmetrized gradient operator and an element from domain and range."""
rng = RandomGenerator(seed=0)
input_shape = (len(dim), 6, 4, 10, 20, 16) # First dimension matches number of gradients

# Generate symmetrized gradient operator
symmetrized_gradient_op = SymmetrizedGradientOp(dim, mode, pad_mode)

u = rng.complex64_tensor(size=input_shape)
v = rng.complex64_tensor(size=(len(dim), *input_shape))
return symmetrized_gradient_op, u, v


@pytest.mark.parametrize('mode', ['central', 'forward', 'backward'])
def test_symmetrized_gradient_op_forward(
mode: Literal['central', 'forward', 'backward'],
) -> None:
"""Test symmetrized gradient of a simple linear vector field in 2D."""
# Create a test object v = (v0, v1) in 2D
size = 10
y_coords = repeat(torch.arange(size, dtype=torch.float32), 'y -> 1 y x', x=size)
x_coords = repeat(torch.arange(size, dtype=torch.float32), 'x -> 1 y x', y=size)
v0 = x_coords + 2 * y_coords # v0(y, x) = x + 2y
v1 = 2 * x_coords + y_coords # v1(y, x) = 2x + y
v0 = v0 - 1j * v0
v1 = v1 - 1j * v1
v = torch.cat((v0, v1), dim=0) # shape (2, size, size)

# Generate and apply symmetrized gradient operator
dim = (-1, -2)
assert v.shape[0] == len(dim) # Ensure first dimension of v matches number of dimensions in dim
sym_grad_op = SymmetrizedGradientOp(dim=dim, mode=mode)
(sym_grad,) = sym_grad_op(v) # shape (2, 2, size, size)

# Extract interior (remove borders to avoid boundary effects)
sym_00 = sym_grad[0, 0, 1:-1, 1:-1]
sym_11 = sym_grad[1, 1, 1:-1, 1:-1]
sym_01 = sym_grad[0, 1, 1:-1, 1:-1]
sym_10 = sym_grad[1, 0, 1:-1, 1:-1]

# Verify correct values excluding borders
torch.testing.assert_close(sym_00, (1 - 1j) * torch.ones_like(sym_00))
torch.testing.assert_close(sym_11, (1 - 1j) * torch.ones_like(sym_11))
torch.testing.assert_close(sym_01, (2 - 2j) * torch.ones_like(sym_01))
torch.testing.assert_close(sym_10, (2 - 2j) * torch.ones_like(sym_10))


@pytest.mark.parametrize('pad_mode', ['zeros', 'circular'])
@pytest.mark.parametrize('mode', ['central', 'forward', 'backward'])
@pytest.mark.parametrize('dim', [(-1,), (-2, -1), (-3, -2, -1), (-4,), (1, 3)])
def test_symmetrized_gradient_op_adjointness(
dim: Sequence[int],
mode: Literal['central', 'forward', 'backward'],
pad_mode: Literal['zeros', 'circular'],
) -> None:
"""Test symmetrized gradient operator adjoint property."""
dotproduct_adjointness_test(*create_symmetrized_gradient_op_and_range_domain(dim, mode, pad_mode))


@pytest.mark.parametrize('pad_mode', ['zeros', 'circular'])
@pytest.mark.parametrize('mode', ['central', 'forward', 'backward'])
@pytest.mark.parametrize('dim', [(-1,), (-2, -1), (-3, -2, -1), (-4,), (1, 3)])
def test_symmetrized_gradient_op_grad(
dim: Sequence[int],
mode: Literal['central', 'forward', 'backward'],
pad_mode: Literal['zeros', 'circular'],
) -> None:
"""Test the gradient of symmetrized gradient operator."""
gradient_of_linear_operator_test(*create_symmetrized_gradient_op_and_range_domain(dim, mode, pad_mode))


@pytest.mark.parametrize('pad_mode', ['zeros', 'circular'])
@pytest.mark.parametrize('mode', ['central', 'forward', 'backward'])
@pytest.mark.parametrize('dim', [(-1,), (-2, -1), (-3, -2, -1), (-4,), (1, 3)])
def test_symmetrized_gradient_op_forward_mode_autodiff(
dim: Sequence[int],
mode: Literal['central', 'forward', 'backward'],
pad_mode: Literal['zeros', 'circular'],
) -> None:
"""Test the forward-mode autodiff of the symmetrized gradient operator."""
forward_mode_autodiff_of_linear_operator_test(*create_symmetrized_gradient_op_and_range_domain(dim, mode, pad_mode))


@pytest.mark.cuda
def test_symmetrized_gradient_op_cuda() -> None:
"""Test symmetrized gradient operator works on CUDA devices."""

# Set dimensional parameters
dim = (-3, -2, -1)
input_shape = (len(dim), 6, 4, 10, 20, 16)

# Generate data
random_generator = RandomGenerator(seed=0)
u = random_generator.complex64_tensor(size=input_shape)

# Create on CPU, run on CPU
symmetrized_gradient_op = SymmetrizedGradientOp(dim, mode='central', pad_mode='circular')
operator = symmetrized_gradient_op.H @ symmetrized_gradient_op
(symmetrized_gradient_output,) = operator(u)
assert symmetrized_gradient_output.is_cpu

# Transfer to GPU, run on GPU
symmetrized_gradient_op = SymmetrizedGradientOp(dim, mode='central', pad_mode='circular')
operator = symmetrized_gradient_op.H @ symmetrized_gradient_op
operator.cuda()
(symmetrized_gradient_output,) = operator(u.cuda())
assert symmetrized_gradient_output.is_cuda