diff --git a/src/mrpro/operators/models/PEX.py b/src/mrpro/operators/models/PEX.py new file mode 100644 index 000000000..2ef50db67 --- /dev/null +++ b/src/mrpro/operators/models/PEX.py @@ -0,0 +1,79 @@ +"""Saturation recovery signal model for T1 mapping.""" + +from collections.abc import Sequence + +import torch + +from mrpro.operators.SignalModel import SignalModel +from mrpro.utils import unsqueeze_right +from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON, volt_to_sqrt_kwatt + + +class PEX(SignalModel[torch.Tensor, torch.Tensor]): + """Signal model for preparation based B1+ mapping (PEX).""" + + def __init__( + self, + voltages: float | torch.Tensor | Sequence[float], + prep_delay: float | torch.Tensor, + pulse_duration: float | torch.Tensor, + n_tx: int = 1, + ) -> None: + """Initialize preparation based B1+ mapping (PEX) signal model. + + Parameters + ---------- + voltages + voltages. Shape `(voltage, ...)`. + prep_delay + preparation delay. Shape `(prepdelay, ...)`. + pulse_duration + rect pulse duration in seconds. Shape `(...)`. + n_tx + number of transmit channels. + """ + super().__init__() + voltages_ = torch.as_tensor(voltages) * n_tx**0.5 + prep_delay_ = torch.as_tensor(prep_delay) + pulse_duration_ = torch.as_tensor(pulse_duration) + self.voltages = torch.nn.Parameter(voltages_, requires_grad=voltages_.requires_grad) + self.prep_delay = torch.nn.Parameter(prep_delay_, requires_grad=prep_delay_.requires_grad) + self.pulse_duration = torch.nn.Parameter(pulse_duration_, requires_grad=pulse_duration_.requires_grad) + + def __call__(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply PEX signal model. + + Parameters + ---------- + b1 + B1+ in µT/sqrt(kW) translating voltage of the coil to flip angle + with shape `(...)`. + t1 + Longitudinal relaxation time. + + Returns + ------- + signal with shape `(voltage/prepdelay, *other, coils, z, y, x)` + """ + return super().__call__(b1, t1) + + def forward(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply PEX signal model. + + .. note:: + Prefer calling the instance of the PEX operator as ``operator(b1, t1)`` over + directly calling this method. + """ + ndim = b1.ndim + voltages = unsqueeze_right(self.voltages, ndim - self.voltages.ndim + 1) # +1 are voltages + prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim + 1) + pulse_duration = unsqueeze_right(self.pulse_duration, ndim - self.pulse_duration.ndim) + + # this is mainly cos(FA), where FA = gamma * b1 * voltage * t + signal = 1 - ( + 1 + - torch.cos( + b1 * volt_to_sqrt_kwatt(voltages) * 1e-6 * pulse_duration * GYROMAGNETIC_RATIO_PROTON * 2 * torch.pi + ) + ) * torch.exp(-prep_delay / t1) + return (signal,) diff --git a/src/mrpro/operators/models/__init__.py b/src/mrpro/operators/models/__init__.py index d2ba22340..7a4ed5ed6 100644 --- a/src/mrpro/operators/models/__init__.py +++ b/src/mrpro/operators/models/__init__.py @@ -10,6 +10,7 @@ from mrpro.operators.models.cMRF import CardiacFingerprinting from mrpro.operators.models.TransientSteadyStateWithPreparation import TransientSteadyStateWithPreparation from mrpro.operators.models import EPG +from mrpro.operators.models.PEX import PEX __all__ = [ "CardiacFingerprinting", @@ -17,9 +18,10 @@ "InversionRecovery", "MOLLI", "MonoExponentialDecay", + "PEX", "SaturationRecovery", "SpoiledGRE", "TransientSteadyStateWithPreparation", "WASABI", "WASABITI" -] +] \ No newline at end of file diff --git a/src/mrpro/utils/unit_conversion.py b/src/mrpro/utils/unit_conversion.py index fefaddb8d..0cad77ce5 100644 --- a/src/mrpro/utils/unit_conversion.py +++ b/src/mrpro/utils/unit_conversion.py @@ -103,6 +103,24 @@ def rad_to_deg(rad: T) -> T: return rad * 180.0 / np.pi +def volt_to_sqrt_kwatt(volt: T) -> T: + """Convert Volt to kilo Watt for 50 Ohm.""" + if isinstance(volt, list): + return [volt_to_sqrt_kwatt(x) for x in volt] + if isinstance(volt, tuple): + return tuple([volt_to_sqrt_kwatt(x) for x in volt]) + return volt / 50e3 ** (0.5) + + +def sqrt_kwatt_to_volt(sqrt_kwatt: T) -> T: + """Convert kilo Watt to Volt for 50 Ohm.""" + if isinstance(sqrt_kwatt, list): + return [sqrt_kwatt_to_volt(x) for x in sqrt_kwatt] + if isinstance(sqrt_kwatt, tuple): + return tuple([sqrt_kwatt_to_volt(x) for x in sqrt_kwatt]) + return sqrt_kwatt * 50e3 ** (0.5) + + def lamor_frequency_to_magnetic_field(lamor_frequency: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON) -> T: """Convert the Lamor frequency [Hz] to the magntic field strength [T]. diff --git a/tests/operators/models/test_pex.py b/tests/operators/models/test_pex.py new file mode 100644 index 000000000..3dd994999 --- /dev/null +++ b/tests/operators/models/test_pex.py @@ -0,0 +1,164 @@ +"""Tests for PEX signal model.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.operators.models import PEX +from mrpro.utils import RandomGenerator +from tests import autodiff_test +from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS + + +@pytest.mark.parametrize( + ('voltages', 'pulse_duration', 'expected_behavior'), + [ + (0, 1.0, 'unity'), # zero voltage should give signal close to 1 + ([0, 100, 1000], 0, 'unity'), # zero pulse duration should give signal close to 1 + ], +) +def test_pex_special_values( + voltages: float | list[float], + pulse_duration: float, + expected_behavior: str, + parameter_shape: Sequence[int] = (2, 5, 10, 10, 10), +) -> None: + """Test PEX signal at special input values.""" + rng = RandomGenerator(0) + prep_delay = 0.01 # short prep delay + + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) # µT/sqrt(kW) + t1 = rng.float32_tensor(parameter_shape, low=0.1, high=2) + + (signal,) = model(b1, t1) + + # For zero voltage or zero pulse duration, signal should be close to 1 + if expected_behavior == 'unity': + expected = torch.ones_like(signal) + torch.testing.assert_close(signal, expected, atol=1e-3, rtol=1e-3) + + +def test_pex_flip_angle_behavior(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: + """Test PEX signal behavior with increasing flip angles.""" + rng = RandomGenerator(1) + voltages = [10, 50, 100] + prep_delay = 0.01 + pulse_duration = 0.001 + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + b1 = rng.float32_tensor(parameter_shape, low=1, high=5) # µT/sqrt(kW) + t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) + + (signal,) = model(b1, t1) + + # Signal should decrease with increasing voltage (higher flip angles) + assert torch.all(signal[0] >= signal[1]) # first voltage < second voltage + assert torch.all(signal[1] >= signal[2]) # second voltage < third voltage + assert signal.isfinite().all() + + +def test_pex_t1_recovery(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: + """Test PEX signal T1 recovery behavior.""" + rng = RandomGenerator(2) + voltages = 100 # fixed voltage + prep_delay = torch.tensor([0.001, 0.01, 0.1, 1.0]) # increasing prep delays + pulse_duration = 0.001 + + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + b1 = rng.float32_tensor(parameter_shape, low=1, high=5) + t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) + + (signal,) = model(b1, t1) + + # Signal should increase with longer prep delay (more T1 recovery) + for i in range(len(prep_delay) - 1): + assert torch.all(signal[i] <= signal[i + 1]) + assert signal.isfinite().all() + + +@SHAPE_VARIATIONS_SIGNAL_MODELS +def test_pex_shape( + parameter_shape: Sequence[int], contrast_dim_shape: Sequence[int], signal_shape: Sequence[int] +) -> None: + """Test correct signal shapes.""" + rng = RandomGenerator(1) + voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) + prep_delay = 0.01 + pulse_duration = 0.001 + + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) + t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) + (signal,) = model(b1, t1) + assert signal.shape == signal_shape + assert signal.isfinite().all() + + +def test_autodiff_pex( + parameter_shape: Sequence[int] = (2, 5, 10), + contrast_dim_shape: Sequence[int] = (13, 2, 5, 10), +) -> None: + """Test autodiff works for PEX model.""" + rng = RandomGenerator(2) + voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) + prep_delay = 0.01 + pulse_duration = 0.001 + + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) + t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) + autodiff_test(model, b1, t1) + + +@pytest.mark.cuda +def test_pex_cuda(parameter_shape: Sequence[int] = (2, 5), contrast_dim_shape: Sequence[int] = (13, 2, 5)) -> None: + """Test the PEX model works on cuda devices.""" + rng = RandomGenerator(3) + voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) + prep_delay = 0.01 + pulse_duration = 0.001 + b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) + t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) + + # Create on CPU, transfer to GPU and run on GPU + model = PEX(voltages=voltages.tolist(), prep_delay=prep_delay, pulse_duration=pulse_duration) + model.cuda() + (signal,) = model(b1.cuda(), t1.cuda()) + assert signal.is_cuda + assert signal.isfinite().all() + + # Create on GPU and run on GPU + model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration) + (signal,) = model(b1.cuda(), t1.cuda()) + assert signal.is_cuda + assert signal.isfinite().all() + + # Create on GPU, transfer to CPU and run on CPU + model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration) + model.cpu() + (signal,) = model(b1, t1) + assert signal.is_cpu + assert signal.isfinite().all() + + +def test_pex_n_tx_scaling(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: + """Test PEX signal scales correctly with number of transmit channels.""" + rng = RandomGenerator(4) + voltages = 100 + prep_delay = 0.01 + pulse_duration = 0.001 + b1 = rng.float32_tensor(parameter_shape, low=1, high=5) + t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) + + # Test with different n_tx values + model_1tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=1) + model_4tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=4) + + (signal_1tx,) = model_1tx(b1, t1) + (signal_4tx,) = model_4tx(b1, t1) + + # With higher n_tx, the effective voltage is scaled by sqrt(n_tx), so flip angle increases + # This should result in lower signal values + assert torch.all(signal_4tx <= signal_1tx) + assert signal_1tx.isfinite().all() + assert signal_4tx.isfinite().all() diff --git a/tests/utils/test_unit_conversion.py b/tests/utils/test_unit_conversion.py index 0976472fc..afa58457a 100644 --- a/tests/utils/test_unit_conversion.py +++ b/tests/utils/test_unit_conversion.py @@ -14,6 +14,8 @@ ms_to_s, rad_to_deg, s_to_ms, + sqrt_kwatt_to_volt, + volt_to_sqrt_kwatt, ) @@ -95,3 +97,52 @@ def test_magnetic_field_to_lamor_frequency(): proton_gyromagnetic_ratio = 42.58 * 1e6 magnetic_field_strength = 3.0 assert magnetic_field_to_lamor_frequency(magnetic_field_strength, proton_gyromagnetic_ratio) == 127.74 * 1e6 + + +def test_volt_to_sqrt_kwatt(): + """Verify Volt to sqrt(kW) conversion.""" + rng = RandomGenerator(seed=0) + volt_input = rng.float32_tensor((3, 4, 5)) + expected = volt_input / (50e3**0.5) + torch.testing.assert_close(volt_to_sqrt_kwatt(volt_input), expected) + + +def test_sqrt_kwatt_to_volt(): + """Verify sqrt(kW) to Volt conversion.""" + rng = RandomGenerator(seed=0) + sqrt_kwatt_input = rng.float32_tensor((3, 4, 5)) + expected = sqrt_kwatt_input * (50e3**0.5) + torch.testing.assert_close(sqrt_kwatt_to_volt(sqrt_kwatt_input), expected) + + +def test_volt_sqrt_kwatt_round_trip(): + """Verify Volt <-> sqrt(kW) conversions are inverse operations.""" + rng = RandomGenerator(seed=42) + volt_input = rng.float32_tensor((3, 4, 5), low=1.0, high=1000.0) + + # Test round trip: volt -> sqrt_kwatt -> volt + sqrt_kwatt_converted = volt_to_sqrt_kwatt(volt_input) + volt_recovered = sqrt_kwatt_to_volt(sqrt_kwatt_converted) + torch.testing.assert_close(volt_input, volt_recovered) + + # Test round trip: sqrt_kwatt -> volt -> sqrt_kwatt + sqrt_kwatt_input = rng.float32_tensor((3, 4, 5), low=0.001, high=1.0) + volt_converted = sqrt_kwatt_to_volt(sqrt_kwatt_input) + sqrt_kwatt_recovered = volt_to_sqrt_kwatt(volt_converted) + torch.testing.assert_close(sqrt_kwatt_input, sqrt_kwatt_recovered) + + +def test_volt_to_sqrt_kwatt_scalar(): + """Verify Volt to sqrt(kW) conversion for scalar values.""" + volt = 100.0 + expected = volt / (50e3**0.5) + result = volt_to_sqrt_kwatt(volt) + assert abs(result - expected) < 1e-6 + + +def test_sqrt_kwatt_to_volt_scalar(): + """Verify sqrt(kW) to Volt conversion for scalar values.""" + sqrt_kwatt = 0.5 + expected = sqrt_kwatt * (50e3**0.5) + result = sqrt_kwatt_to_volt(sqrt_kwatt) + assert abs(result - expected) < 1e-6