Skip to content
79 changes: 79 additions & 0 deletions src/mrpro/operators/models/PEX.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""Saturation recovery signal model for T1 mapping."""

from collections.abc import Sequence

import torch

from mrpro.operators.SignalModel import SignalModel
from mrpro.utils import unsqueeze_right
from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON, volt_to_sqrt_kwatt


class PEX(SignalModel[torch.Tensor, torch.Tensor]):
"""Signal model for preparation based B1+ mapping (PEX)."""

def __init__(
self,
voltages: float | torch.Tensor | Sequence[float],
prep_delay: float | torch.Tensor,
pulse_duration: float | torch.Tensor,
n_tx: int = 1,
) -> None:
"""Initialize preparation based B1+ mapping (PEX) signal model.

Parameters
----------
voltages
voltages. Shape `(voltage, ...)`.
prep_delay
preparation delay. Shape `(prepdelay, ...)`.
pulse_duration
rect pulse duration in seconds. Shape `(...)`.
n_tx
number of transmit channels.
"""
super().__init__()
voltages_ = torch.as_tensor(voltages) * n_tx**0.5
prep_delay_ = torch.as_tensor(prep_delay)
pulse_duration_ = torch.as_tensor(pulse_duration)
self.voltages = torch.nn.Parameter(voltages_, requires_grad=voltages_.requires_grad)
self.prep_delay = torch.nn.Parameter(prep_delay_, requires_grad=prep_delay_.requires_grad)
self.pulse_duration = torch.nn.Parameter(pulse_duration_, requires_grad=pulse_duration_.requires_grad)

def __call__(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply PEX signal model.

Parameters
----------
b1
B1+ in µT/sqrt(kW) translating voltage of the coil to flip angle
with shape `(...)`.
t1
Longitudinal relaxation time.

Returns
-------
signal with shape `(voltage/prepdelay, *other, coils, z, y, x)`
"""
return super().__call__(b1, t1)

def forward(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply PEX signal model.

.. note::
Prefer calling the instance of the PEX operator as ``operator(b1, t1)`` over
directly calling this method.
"""
ndim = b1.ndim
voltages = unsqueeze_right(self.voltages, ndim - self.voltages.ndim + 1) # +1 are voltages
prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim + 1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me the test_pex_t1_recovery Test was failing due to dimension issues in the signal function (prep_delay and t1 are arrays). This is fixed with moving this to the same dimension as voltages, however I am not sure whether this is a good idea? Usually one of them should only be scalar, could add some assert statement.

pulse_duration = unsqueeze_right(self.pulse_duration, ndim - self.pulse_duration.ndim)

# this is mainly cos(FA), where FA = gamma * b1 * voltage * t
signal = 1 - (
1
- torch.cos(
b1 * volt_to_sqrt_kwatt(voltages) * 1e-6 * pulse_duration * GYROMAGNETIC_RATIO_PROTON * 2 * torch.pi
)
) * torch.exp(-prep_delay / t1)
return (signal,)
4 changes: 3 additions & 1 deletion src/mrpro/operators/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@
from mrpro.operators.models.cMRF import CardiacFingerprinting
from mrpro.operators.models.TransientSteadyStateWithPreparation import TransientSteadyStateWithPreparation
from mrpro.operators.models import EPG
from mrpro.operators.models.PEX import PEX

__all__ = [
"CardiacFingerprinting",
"EPG",
"InversionRecovery",
"MOLLI",
"MonoExponentialDecay",
"PEX",
"SaturationRecovery",
"SpoiledGRE",
"TransientSteadyStateWithPreparation",
"WASABI",
"WASABITI"
]
]
18 changes: 18 additions & 0 deletions src/mrpro/utils/unit_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,24 @@ def rad_to_deg(rad: T) -> T:
return rad * 180.0 / np.pi


def volt_to_sqrt_kwatt(volt: T) -> T:
"""Convert Volt to kilo Watt for 50 Ohm."""
if isinstance(volt, list):
return [volt_to_sqrt_kwatt(x) for x in volt]
if isinstance(volt, tuple):
return tuple([volt_to_sqrt_kwatt(x) for x in volt])
return volt / 50e3 ** (0.5)


def sqrt_kwatt_to_volt(sqrt_kwatt: T) -> T:
"""Convert kilo Watt to Volt for 50 Ohm."""
if isinstance(sqrt_kwatt, list):
return [sqrt_kwatt_to_volt(x) for x in sqrt_kwatt]
if isinstance(sqrt_kwatt, tuple):
return tuple([sqrt_kwatt_to_volt(x) for x in sqrt_kwatt])
return sqrt_kwatt * 50e3 ** (0.5)


def lamor_frequency_to_magnetic_field(lamor_frequency: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON) -> T:
"""Convert the Lamor frequency [Hz] to the magntic field strength [T].

Expand Down
164 changes: 164 additions & 0 deletions tests/operators/models/test_pex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Tests for PEX signal model."""

from collections.abc import Sequence

import pytest
import torch
from mrpro.operators.models import PEX
from mrpro.utils import RandomGenerator
from tests import autodiff_test
from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS


@pytest.mark.parametrize(
('voltages', 'pulse_duration', 'expected_behavior'),
[
(0, 1.0, 'unity'), # zero voltage should give signal close to 1
([0, 100, 1000], 0, 'unity'), # zero pulse duration should give signal close to 1
],
)
def test_pex_special_values(
voltages: float | list[float],
pulse_duration: float,
expected_behavior: str,
parameter_shape: Sequence[int] = (2, 5, 10, 10, 10),
) -> None:
"""Test PEX signal at special input values."""
rng = RandomGenerator(0)
prep_delay = 0.01 # short prep delay

model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration)
b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) # µT/sqrt(kW)
t1 = rng.float32_tensor(parameter_shape, low=0.1, high=2)

(signal,) = model(b1, t1)

# For zero voltage or zero pulse duration, signal should be close to 1
if expected_behavior == 'unity':
expected = torch.ones_like(signal)
torch.testing.assert_close(signal, expected, atol=1e-3, rtol=1e-3)


def test_pex_flip_angle_behavior(parameter_shape: Sequence[int] = (2, 5, 10)) -> None:
"""Test PEX signal behavior with increasing flip angles."""
rng = RandomGenerator(1)
voltages = [10, 50, 100]
prep_delay = 0.01
pulse_duration = 0.001
model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration)
b1 = rng.float32_tensor(parameter_shape, low=1, high=5) # µT/sqrt(kW)
t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2)

(signal,) = model(b1, t1)

# Signal should decrease with increasing voltage (higher flip angles)
assert torch.all(signal[0] >= signal[1]) # first voltage < second voltage
assert torch.all(signal[1] >= signal[2]) # second voltage < third voltage
assert signal.isfinite().all()


def test_pex_t1_recovery(parameter_shape: Sequence[int] = (2, 5, 10)) -> None:
"""Test PEX signal T1 recovery behavior."""
rng = RandomGenerator(2)
voltages = 100 # fixed voltage
prep_delay = torch.tensor([0.001, 0.01, 0.1, 1.0]) # increasing prep delays
pulse_duration = 0.001

model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration)
b1 = rng.float32_tensor(parameter_shape, low=1, high=5)
t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2)

(signal,) = model(b1, t1)

# Signal should increase with longer prep delay (more T1 recovery)
for i in range(len(prep_delay) - 1):
assert torch.all(signal[i] <= signal[i + 1])
assert signal.isfinite().all()


@SHAPE_VARIATIONS_SIGNAL_MODELS
def test_pex_shape(
parameter_shape: Sequence[int], contrast_dim_shape: Sequence[int], signal_shape: Sequence[int]
) -> None:
"""Test correct signal shapes."""
rng = RandomGenerator(1)
voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200)
prep_delay = 0.01
pulse_duration = 0.001

model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration)
b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10)
t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2)
(signal,) = model(b1, t1)
assert signal.shape == signal_shape
assert signal.isfinite().all()


def test_autodiff_pex(
parameter_shape: Sequence[int] = (2, 5, 10),
contrast_dim_shape: Sequence[int] = (13, 2, 5, 10),
) -> None:
"""Test autodiff works for PEX model."""
rng = RandomGenerator(2)
voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200)
prep_delay = 0.01
pulse_duration = 0.001

model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration)
b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10)
t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2)
autodiff_test(model, b1, t1)


@pytest.mark.cuda
def test_pex_cuda(parameter_shape: Sequence[int] = (2, 5), contrast_dim_shape: Sequence[int] = (13, 2, 5)) -> None:
"""Test the PEX model works on cuda devices."""
rng = RandomGenerator(3)
voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200)
prep_delay = 0.01
pulse_duration = 0.001
b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10)
t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2)

# Create on CPU, transfer to GPU and run on GPU
model = PEX(voltages=voltages.tolist(), prep_delay=prep_delay, pulse_duration=pulse_duration)
model.cuda()
(signal,) = model(b1.cuda(), t1.cuda())
assert signal.is_cuda
assert signal.isfinite().all()

# Create on GPU and run on GPU
model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration)
(signal,) = model(b1.cuda(), t1.cuda())
assert signal.is_cuda
assert signal.isfinite().all()

# Create on GPU, transfer to CPU and run on CPU
model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration)
model.cpu()
(signal,) = model(b1, t1)
assert signal.is_cpu
assert signal.isfinite().all()


def test_pex_n_tx_scaling(parameter_shape: Sequence[int] = (2, 5, 10)) -> None:
"""Test PEX signal scales correctly with number of transmit channels."""
rng = RandomGenerator(4)
voltages = 100
prep_delay = 0.01
pulse_duration = 0.001
b1 = rng.float32_tensor(parameter_shape, low=1, high=5)
t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2)

# Test with different n_tx values
model_1tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=1)
model_4tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=4)

(signal_1tx,) = model_1tx(b1, t1)
(signal_4tx,) = model_4tx(b1, t1)

# With higher n_tx, the effective voltage is scaled by sqrt(n_tx), so flip angle increases
# This should result in lower signal values
assert torch.all(signal_4tx <= signal_1tx)
assert signal_1tx.isfinite().all()
assert signal_4tx.isfinite().all()
51 changes: 51 additions & 0 deletions tests/utils/test_unit_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
ms_to_s,
rad_to_deg,
s_to_ms,
sqrt_kwatt_to_volt,
volt_to_sqrt_kwatt,
)


Expand Down Expand Up @@ -95,3 +97,52 @@ def test_magnetic_field_to_lamor_frequency():
proton_gyromagnetic_ratio = 42.58 * 1e6
magnetic_field_strength = 3.0
assert magnetic_field_to_lamor_frequency(magnetic_field_strength, proton_gyromagnetic_ratio) == 127.74 * 1e6


def test_volt_to_sqrt_kwatt():
"""Verify Volt to sqrt(kW) conversion."""
rng = RandomGenerator(seed=0)
volt_input = rng.float32_tensor((3, 4, 5))
expected = volt_input / (50e3**0.5)
torch.testing.assert_close(volt_to_sqrt_kwatt(volt_input), expected)


def test_sqrt_kwatt_to_volt():
"""Verify sqrt(kW) to Volt conversion."""
rng = RandomGenerator(seed=0)
sqrt_kwatt_input = rng.float32_tensor((3, 4, 5))
expected = sqrt_kwatt_input * (50e3**0.5)
torch.testing.assert_close(sqrt_kwatt_to_volt(sqrt_kwatt_input), expected)


def test_volt_sqrt_kwatt_round_trip():
"""Verify Volt <-> sqrt(kW) conversions are inverse operations."""
rng = RandomGenerator(seed=42)
volt_input = rng.float32_tensor((3, 4, 5), low=1.0, high=1000.0)

# Test round trip: volt -> sqrt_kwatt -> volt
sqrt_kwatt_converted = volt_to_sqrt_kwatt(volt_input)
volt_recovered = sqrt_kwatt_to_volt(sqrt_kwatt_converted)
torch.testing.assert_close(volt_input, volt_recovered)

# Test round trip: sqrt_kwatt -> volt -> sqrt_kwatt
sqrt_kwatt_input = rng.float32_tensor((3, 4, 5), low=0.001, high=1.0)
volt_converted = sqrt_kwatt_to_volt(sqrt_kwatt_input)
sqrt_kwatt_recovered = volt_to_sqrt_kwatt(volt_converted)
torch.testing.assert_close(sqrt_kwatt_input, sqrt_kwatt_recovered)


def test_volt_to_sqrt_kwatt_scalar():
"""Verify Volt to sqrt(kW) conversion for scalar values."""
volt = 100.0
expected = volt / (50e3**0.5)
result = volt_to_sqrt_kwatt(volt)
assert abs(result - expected) < 1e-6


def test_sqrt_kwatt_to_volt_scalar():
"""Verify sqrt(kW) to Volt conversion for scalar values."""
sqrt_kwatt = 0.5
expected = sqrt_kwatt * (50e3**0.5)
result = sqrt_kwatt_to_volt(sqrt_kwatt)
assert abs(result - expected) < 1e-6
Loading