From 1817365d1edb2535624514c617aac290aac72e2d Mon Sep 17 00:00:00 2001 From: mxlutz Date: Thu, 15 May 2025 19:20:07 +0200 Subject: [PATCH 1/9] first version, needs clean up, fit function simple for slow pex --- examples/scripts/pex_mrpro.py | 123 +++++++++++++++++++++++++ src/mrpro/operators/models/PEX.py | 74 +++++++++++++++ src/mrpro/operators/models/__init__.py | 4 +- src/mrpro/utils/unit_conversion.py | 18 ++++ 4 files changed, 218 insertions(+), 1 deletion(-) create mode 100755 examples/scripts/pex_mrpro.py create mode 100644 src/mrpro/operators/models/PEX.py diff --git a/examples/scripts/pex_mrpro.py b/examples/scripts/pex_mrpro.py new file mode 100755 index 000000000..e433fe432 --- /dev/null +++ b/examples/scripts/pex_mrpro.py @@ -0,0 +1,123 @@ +# %% +import matplotlib.pyplot as plt +import mrpro +import torch +from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON, sqrt_kwatt_to_volt +from pathlib import Path + +data_folder = Path('RAW') +# %% Reco +# Read raw data and trajectory + +kdata = mrpro.data.KData.from_file( + data_folder / 'meas_MID59_PEX_slow_B1_500us_RECT_60V_FID18835.h5', + mrpro.data.traj_calculators.KTrajectoryCartesian(), +).remove_readout_os() + +csm = mrpro.data.CsmData.from_kdata_inati(kdata[12]) +reco = mrpro.algorithms.reconstruction.DirectReconstruction(kdata, csm=csm) +img = reco(kdata).data +img = img.flip(dims=(0,)) + + +def get_pex_special_tab(kdata): + voltages = [] + for i in range(kdata.shape[0]): + voltages.append(kdata.header._misc['userParameters']['userParameterDouble'][68 + i]['value']) + + pulse_duration = kdata.header._misc['userParameters']['userParameterDouble'][18]['value'] * 1e-6 + prep_delay = kdata.header._misc['userParameters']['userParameterDouble'][11]['value'] * 1e-6 + return voltages[::-1], pulse_duration, prep_delay + + +voltages, pulse_duration, prep_delay = get_pex_special_tab(kdata) +print(f'voltages: {voltages} V') +print(f'pulse_duration: {pulse_duration} s') +print(f'prep_delay: {prep_delay} s') + +# %% + +img_fit = img / img[0, ...] + +img_abs = img_fit.abs().squeeze() +img_phase = img_fit.angle().squeeze() +img_sign = torch.ones_like(img_abs) +img_sign[img_phase.abs() > torch.pi / 2] = -1 +img_abs_sign = img_abs * img_sign + +# Plot coil-combined images for PEX +fig, axs = plt.subplots(2, img.shape[0], figsize=(10, 3)) +for i in range(img.shape[0]): + im1 = axs[0, i].imshow(img_abs[i, ...], cmap='gray', vmin=0, vmax=1) + im2 = axs[1, i].imshow(img_phase[i, ...], cmap='turbo', vmin=-torch.pi, vmax=torch.pi) + +fig.colorbar(im1, ax=axs[0, -1]) +fig.colorbar(im2, ax=axs[1, -1]) +plt.show() + +# Plot coil-combined images for PEX +fig, axs = plt.subplots(1, img.shape[0], figsize=(10, 3)) +for i in range(img.shape[0]): + im1 = axs[i].imshow(img_abs_sign[i, ...], cmap='seismic', vmin=-1, vmax=1) +fig.colorbar(im1, ax=axs[-1]) +plt.show() + +# plot signal of one pixel +plt.figure() +plt.plot(voltages, img_abs_sign[:, 32, 32], marker='o', linestyle='None') +plt.show() +# %% define signal model +model_op = mrpro.operators.models.PexSimple(voltages, prep_delay, t1=1, pulse_duration=pulse_duration, n_tx=8) + +# test signal model +# zero crossing should be close to 90/a +a = torch.tensor([2, 4, 10]) +(signal,) = model_op.forward(a) + +plt.figure() +plt.plot(voltages, signal[..., 2], marker='o', linestyle='None') +plt.show() + +# %% +dictionary = mrpro.operators.DictionaryMatchOp(model_op).append(torch.linspace(1, 100, 10000)) +(a_start,) = dictionary(img_abs_sign) + +# mask out data from approx 135° +voltage_mask = sqrt_kwatt_to_volt( + (3 / 4 * torch.pi) / (GYROMAGNETIC_RATIO_PROTON * 2 * torch.pi * a_start * 1e-6 * pulse_duration) +) * 8 ** (-0.5) + +weight = torch.ones(13, 64, 64) +for i in range(voltage_mask.shape[0]): + for j in range(voltage_mask.shape[1]): + voltages_tensor = torch.tensor(voltages) if isinstance(voltages, list) else voltages + weight[:, i, j] = voltages_tensor < voltage_mask[i, j] + +mse_loss = mrpro.operators.functionals.MSE(img_abs_sign, weight=weight) + +constraints_op = mrpro.operators.ConstraintsOp( + bounds=( + (0, 100), # a is constrained between 1 and 100 µT/sqrt(kW) + ) +) +functional = mse_loss @ model_op @ constraints_op +initial_parameters = constraints_op.inverse(a_start) + + +(result,) = constraints_op(*mrpro.algorithms.optimizers.lbfgs(functional, initial_parameters=initial_parameters)) +result = result.detach().cpu().squeeze() +# %% +plt.figure() +plt.imshow(result, vmin=0, vmax=60) +plt.colorbar() + +plt_idx = [20, 20] + +plt.figure() +plt.plot(voltages, img_abs_sign[:, plt_idx[0], plt_idx[1]], marker='o', linestyle='None', label='data') +plt.plot(voltages, model_op.forward(result[plt_idx[0], plt_idx[1]])[0], marker='o', linestyle='None', label='fit') +plt.plot(voltages, model_op.forward(a_start[plt_idx[0], plt_idx[1]])[0], marker='o', linestyle='None', label='initial') +plt.legend() +plt.show() + +# %% diff --git a/src/mrpro/operators/models/PEX.py b/src/mrpro/operators/models/PEX.py new file mode 100644 index 000000000..55ecd0409 --- /dev/null +++ b/src/mrpro/operators/models/PEX.py @@ -0,0 +1,74 @@ +"""Saturation recovery signal model for T1 mapping.""" + +from collections.abc import Sequence + +import torch + +from mrpro.operators.SignalModel import SignalModel +from mrpro.utils import unsqueeze_right +from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON, volt_to_sqrt_kwatt + + +class PexSimple(SignalModel[torch.Tensor, torch.Tensor]): + """Signal model for preparation based B1+ mapping (PEX).""" + + def __init__( + self, + voltages: float | torch.Tensor | Sequence[int], + prep_delay: float | torch.Tensor, + t1: float | torch.Tensor, + pulse_duration: float | torch.Tensor, + n_tx: int = 1, + ) -> None: + """Initialize preparation based B1+ mapping (PEX) signal model. + + Parameters + ---------- + voltages + voltages. Shape `(Voltages, ...)`. + prep_delay + preparation delay. Shape `(1, ...)`. + t1 + longitudinal relaxation time T1. Shape `(1, ...)`. + pulse_duration + rect pulse duration in seconds. Shape `(1, ...)`. + n_tx + number of transmit channels. + """ + super().__init__() + voltages = torch.as_tensor(voltages) * torch.sqrt(torch.tensor(n_tx, dtype=torch.float)) + prep_delay = torch.as_tensor(prep_delay) + t1 = torch.as_tensor(t1) + pulse_duration = torch.as_tensor(pulse_duration) + self.voltages = torch.nn.Parameter(voltages, requires_grad=voltages.requires_grad) + self.prep_delay = torch.nn.Parameter(prep_delay, requires_grad=prep_delay.requires_grad) + self.t1 = torch.nn.Parameter(t1, requires_grad=t1.requires_grad) + self.pulse_duration = torch.nn.Parameter(pulse_duration, requires_grad=pulse_duration.requires_grad) + + def forward(self, a: torch.tensor) -> tuple[torch.Tensor,]: + """Apply PEX signal model. + + Parameters + ---------- + a + parameter a in µT/sqrt(kW) translating voltage of the coil to flip angle + with shape `(*other, coils, z, y, x)` + + Returns + ------- + signal with shape `(voltage, *other, coils, z, y, x)` + """ + ndim = a.ndim + voltages = unsqueeze_right(self.voltages, ndim - self.voltages.ndim + 1) + prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim + 1) + t1 = unsqueeze_right(self.t1, ndim - self.t1.ndim + 1) + pulse_duration = unsqueeze_right(self.pulse_duration, ndim - self.pulse_duration.ndim + 1) + + # this is mainly cos(FA), where FA = gamma * a * voltage * t + signal = 1 - ( + 1 + - torch.cos( + a * volt_to_sqrt_kwatt(voltages) * 1e-6 * pulse_duration * GYROMAGNETIC_RATIO_PROTON * 2 * torch.pi + ) + ) * torch.exp(-prep_delay / t1) + return (signal,) diff --git a/src/mrpro/operators/models/__init__.py b/src/mrpro/operators/models/__init__.py index d2ba22340..4be8574e2 100644 --- a/src/mrpro/operators/models/__init__.py +++ b/src/mrpro/operators/models/__init__.py @@ -10,6 +10,7 @@ from mrpro.operators.models.cMRF import CardiacFingerprinting from mrpro.operators.models.TransientSteadyStateWithPreparation import TransientSteadyStateWithPreparation from mrpro.operators.models import EPG +from mrpro.operators.models.PEX import PexSimple __all__ = [ "CardiacFingerprinting", @@ -21,5 +22,6 @@ "SpoiledGRE", "TransientSteadyStateWithPreparation", "WASABI", - "WASABITI" + "WASABITI", + "PEX" ] diff --git a/src/mrpro/utils/unit_conversion.py b/src/mrpro/utils/unit_conversion.py index 3b690ecf1..ff433dc13 100644 --- a/src/mrpro/utils/unit_conversion.py +++ b/src/mrpro/utils/unit_conversion.py @@ -83,6 +83,24 @@ def rad_to_deg(rad: T) -> T: return rad * 180.0 / np.pi +def volt_to_sqrt_kwatt(volt: T) -> T: + """Convert Volt to kilo Watt for 50 Ohm.""" + if isinstance(volt, list): + return [volt_to_sqrt_kwatt(x) for x in volt] + if isinstance(volt, tuple): + return tuple([volt_to_sqrt_kwatt(x) for x in volt]) + return volt / 50e3 ** (0.5) + + +def sqrt_kwatt_to_volt(sqrt_kwatt: T) -> T: + """Convert kilo Watt to Volt for 50 Ohm.""" + if isinstance(sqrt_kwatt, list): + return [sqrt_kwatt_to_volt(x) for x in kwatt] + if isinstance(sqrt_kwatt, tuple): + return tuple([sqrt_kwatt_to_volt(x) for x in sqrt_kwatt]) + return sqrt_kwatt * 50e3 ** (0.5) + + def lamor_frequency_to_magnetic_field(lamor_frequency: T, gyromagnetic_ratio: float = GYROMAGNETIC_RATIO_PROTON) -> T: """Convert the Lamor frequency [Hz] to the magntic field strength [T]. From 62a47010e89c0765a5156869afa9e2187ed76778 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 17 Jun 2025 00:18:47 +0200 Subject: [PATCH 2/9] add tests. clean up --- examples/scripts/pex_mrpro.py | 123 ------------------------- src/mrpro/operators/models/PEX.py | 32 +++---- src/mrpro/operators/models/__init__.py | 8 +- src/mrpro/utils/unit_conversion.py | 2 +- tests/utils/test_unit_conversion.py | 51 ++++++++++ 5 files changed, 71 insertions(+), 145 deletions(-) delete mode 100755 examples/scripts/pex_mrpro.py diff --git a/examples/scripts/pex_mrpro.py b/examples/scripts/pex_mrpro.py deleted file mode 100755 index e433fe432..000000000 --- a/examples/scripts/pex_mrpro.py +++ /dev/null @@ -1,123 +0,0 @@ -# %% -import matplotlib.pyplot as plt -import mrpro -import torch -from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON, sqrt_kwatt_to_volt -from pathlib import Path - -data_folder = Path('RAW') -# %% Reco -# Read raw data and trajectory - -kdata = mrpro.data.KData.from_file( - data_folder / 'meas_MID59_PEX_slow_B1_500us_RECT_60V_FID18835.h5', - mrpro.data.traj_calculators.KTrajectoryCartesian(), -).remove_readout_os() - -csm = mrpro.data.CsmData.from_kdata_inati(kdata[12]) -reco = mrpro.algorithms.reconstruction.DirectReconstruction(kdata, csm=csm) -img = reco(kdata).data -img = img.flip(dims=(0,)) - - -def get_pex_special_tab(kdata): - voltages = [] - for i in range(kdata.shape[0]): - voltages.append(kdata.header._misc['userParameters']['userParameterDouble'][68 + i]['value']) - - pulse_duration = kdata.header._misc['userParameters']['userParameterDouble'][18]['value'] * 1e-6 - prep_delay = kdata.header._misc['userParameters']['userParameterDouble'][11]['value'] * 1e-6 - return voltages[::-1], pulse_duration, prep_delay - - -voltages, pulse_duration, prep_delay = get_pex_special_tab(kdata) -print(f'voltages: {voltages} V') -print(f'pulse_duration: {pulse_duration} s') -print(f'prep_delay: {prep_delay} s') - -# %% - -img_fit = img / img[0, ...] - -img_abs = img_fit.abs().squeeze() -img_phase = img_fit.angle().squeeze() -img_sign = torch.ones_like(img_abs) -img_sign[img_phase.abs() > torch.pi / 2] = -1 -img_abs_sign = img_abs * img_sign - -# Plot coil-combined images for PEX -fig, axs = plt.subplots(2, img.shape[0], figsize=(10, 3)) -for i in range(img.shape[0]): - im1 = axs[0, i].imshow(img_abs[i, ...], cmap='gray', vmin=0, vmax=1) - im2 = axs[1, i].imshow(img_phase[i, ...], cmap='turbo', vmin=-torch.pi, vmax=torch.pi) - -fig.colorbar(im1, ax=axs[0, -1]) -fig.colorbar(im2, ax=axs[1, -1]) -plt.show() - -# Plot coil-combined images for PEX -fig, axs = plt.subplots(1, img.shape[0], figsize=(10, 3)) -for i in range(img.shape[0]): - im1 = axs[i].imshow(img_abs_sign[i, ...], cmap='seismic', vmin=-1, vmax=1) -fig.colorbar(im1, ax=axs[-1]) -plt.show() - -# plot signal of one pixel -plt.figure() -plt.plot(voltages, img_abs_sign[:, 32, 32], marker='o', linestyle='None') -plt.show() -# %% define signal model -model_op = mrpro.operators.models.PexSimple(voltages, prep_delay, t1=1, pulse_duration=pulse_duration, n_tx=8) - -# test signal model -# zero crossing should be close to 90/a -a = torch.tensor([2, 4, 10]) -(signal,) = model_op.forward(a) - -plt.figure() -plt.plot(voltages, signal[..., 2], marker='o', linestyle='None') -plt.show() - -# %% -dictionary = mrpro.operators.DictionaryMatchOp(model_op).append(torch.linspace(1, 100, 10000)) -(a_start,) = dictionary(img_abs_sign) - -# mask out data from approx 135° -voltage_mask = sqrt_kwatt_to_volt( - (3 / 4 * torch.pi) / (GYROMAGNETIC_RATIO_PROTON * 2 * torch.pi * a_start * 1e-6 * pulse_duration) -) * 8 ** (-0.5) - -weight = torch.ones(13, 64, 64) -for i in range(voltage_mask.shape[0]): - for j in range(voltage_mask.shape[1]): - voltages_tensor = torch.tensor(voltages) if isinstance(voltages, list) else voltages - weight[:, i, j] = voltages_tensor < voltage_mask[i, j] - -mse_loss = mrpro.operators.functionals.MSE(img_abs_sign, weight=weight) - -constraints_op = mrpro.operators.ConstraintsOp( - bounds=( - (0, 100), # a is constrained between 1 and 100 µT/sqrt(kW) - ) -) -functional = mse_loss @ model_op @ constraints_op -initial_parameters = constraints_op.inverse(a_start) - - -(result,) = constraints_op(*mrpro.algorithms.optimizers.lbfgs(functional, initial_parameters=initial_parameters)) -result = result.detach().cpu().squeeze() -# %% -plt.figure() -plt.imshow(result, vmin=0, vmax=60) -plt.colorbar() - -plt_idx = [20, 20] - -plt.figure() -plt.plot(voltages, img_abs_sign[:, plt_idx[0], plt_idx[1]], marker='o', linestyle='None', label='data') -plt.plot(voltages, model_op.forward(result[plt_idx[0], plt_idx[1]])[0], marker='o', linestyle='None', label='fit') -plt.plot(voltages, model_op.forward(a_start[plt_idx[0], plt_idx[1]])[0], marker='o', linestyle='None', label='initial') -plt.legend() -plt.show() - -# %% diff --git a/src/mrpro/operators/models/PEX.py b/src/mrpro/operators/models/PEX.py index 55ecd0409..d856140f7 100644 --- a/src/mrpro/operators/models/PEX.py +++ b/src/mrpro/operators/models/PEX.py @@ -9,14 +9,13 @@ from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON, volt_to_sqrt_kwatt -class PexSimple(SignalModel[torch.Tensor, torch.Tensor]): +class PEX(SignalModel[torch.Tensor, torch.Tensor]): """Signal model for preparation based B1+ mapping (PEX).""" def __init__( self, - voltages: float | torch.Tensor | Sequence[int], + voltages: float | torch.Tensor | Sequence[float], prep_delay: float | torch.Tensor, - t1: float | torch.Tensor, pulse_duration: float | torch.Tensor, n_tx: int = 1, ) -> None: @@ -27,42 +26,41 @@ def __init__( voltages voltages. Shape `(Voltages, ...)`. prep_delay - preparation delay. Shape `(1, ...)`. - t1 - longitudinal relaxation time T1. Shape `(1, ...)`. + preparation delay. Shape `(...)`. pulse_duration - rect pulse duration in seconds. Shape `(1, ...)`. + rect pulse duration in seconds. Shape `(...)`. n_tx number of transmit channels. """ super().__init__() - voltages = torch.as_tensor(voltages) * torch.sqrt(torch.tensor(n_tx, dtype=torch.float)) + voltages = torch.as_tensor(voltages) * n_tx**0.5 prep_delay = torch.as_tensor(prep_delay) - t1 = torch.as_tensor(t1) pulse_duration = torch.as_tensor(pulse_duration) self.voltages = torch.nn.Parameter(voltages, requires_grad=voltages.requires_grad) self.prep_delay = torch.nn.Parameter(prep_delay, requires_grad=prep_delay.requires_grad) - self.t1 = torch.nn.Parameter(t1, requires_grad=t1.requires_grad) self.pulse_duration = torch.nn.Parameter(pulse_duration, requires_grad=pulse_duration.requires_grad) - def forward(self, a: torch.tensor) -> tuple[torch.Tensor,]: + def forward(self, a: torch.tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: """Apply PEX signal model. Parameters ---------- a - parameter a in µT/sqrt(kW) translating voltage of the coil to flip angle - with shape `(*other, coils, z, y, x)` + Parameter a in µT/sqrt(kW) translating voltage of the coil to flip angle + with shape `(...)`. + t1 + Longitudinal relaxation time. + Returns ------- signal with shape `(voltage, *other, coils, z, y, x)` """ ndim = a.ndim - voltages = unsqueeze_right(self.voltages, ndim - self.voltages.ndim + 1) - prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim + 1) - t1 = unsqueeze_right(self.t1, ndim - self.t1.ndim + 1) - pulse_duration = unsqueeze_right(self.pulse_duration, ndim - self.pulse_duration.ndim + 1) + voltages = unsqueeze_right(self.voltages, ndim - self.voltages.ndim + 1) # +1 are voltages + prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim) + t1 = unsqueeze_right(t1, ndim - t1.ndim) + pulse_duration = unsqueeze_right(self.pulse_duration, ndim - self.pulse_duration.ndim) # this is mainly cos(FA), where FA = gamma * a * voltage * t signal = 1 - ( diff --git a/src/mrpro/operators/models/__init__.py b/src/mrpro/operators/models/__init__.py index 4be8574e2..7a4ed5ed6 100644 --- a/src/mrpro/operators/models/__init__.py +++ b/src/mrpro/operators/models/__init__.py @@ -10,7 +10,7 @@ from mrpro.operators.models.cMRF import CardiacFingerprinting from mrpro.operators.models.TransientSteadyStateWithPreparation import TransientSteadyStateWithPreparation from mrpro.operators.models import EPG -from mrpro.operators.models.PEX import PexSimple +from mrpro.operators.models.PEX import PEX __all__ = [ "CardiacFingerprinting", @@ -18,10 +18,10 @@ "InversionRecovery", "MOLLI", "MonoExponentialDecay", + "PEX", "SaturationRecovery", "SpoiledGRE", "TransientSteadyStateWithPreparation", "WASABI", - "WASABITI", - "PEX" -] + "WASABITI" +] \ No newline at end of file diff --git a/src/mrpro/utils/unit_conversion.py b/src/mrpro/utils/unit_conversion.py index 57adaaf1c..0cad77ce5 100644 --- a/src/mrpro/utils/unit_conversion.py +++ b/src/mrpro/utils/unit_conversion.py @@ -115,7 +115,7 @@ def volt_to_sqrt_kwatt(volt: T) -> T: def sqrt_kwatt_to_volt(sqrt_kwatt: T) -> T: """Convert kilo Watt to Volt for 50 Ohm.""" if isinstance(sqrt_kwatt, list): - return [sqrt_kwatt_to_volt(x) for x in kwatt] + return [sqrt_kwatt_to_volt(x) for x in sqrt_kwatt] if isinstance(sqrt_kwatt, tuple): return tuple([sqrt_kwatt_to_volt(x) for x in sqrt_kwatt]) return sqrt_kwatt * 50e3 ** (0.5) diff --git a/tests/utils/test_unit_conversion.py b/tests/utils/test_unit_conversion.py index 0976472fc..afa58457a 100644 --- a/tests/utils/test_unit_conversion.py +++ b/tests/utils/test_unit_conversion.py @@ -14,6 +14,8 @@ ms_to_s, rad_to_deg, s_to_ms, + sqrt_kwatt_to_volt, + volt_to_sqrt_kwatt, ) @@ -95,3 +97,52 @@ def test_magnetic_field_to_lamor_frequency(): proton_gyromagnetic_ratio = 42.58 * 1e6 magnetic_field_strength = 3.0 assert magnetic_field_to_lamor_frequency(magnetic_field_strength, proton_gyromagnetic_ratio) == 127.74 * 1e6 + + +def test_volt_to_sqrt_kwatt(): + """Verify Volt to sqrt(kW) conversion.""" + rng = RandomGenerator(seed=0) + volt_input = rng.float32_tensor((3, 4, 5)) + expected = volt_input / (50e3**0.5) + torch.testing.assert_close(volt_to_sqrt_kwatt(volt_input), expected) + + +def test_sqrt_kwatt_to_volt(): + """Verify sqrt(kW) to Volt conversion.""" + rng = RandomGenerator(seed=0) + sqrt_kwatt_input = rng.float32_tensor((3, 4, 5)) + expected = sqrt_kwatt_input * (50e3**0.5) + torch.testing.assert_close(sqrt_kwatt_to_volt(sqrt_kwatt_input), expected) + + +def test_volt_sqrt_kwatt_round_trip(): + """Verify Volt <-> sqrt(kW) conversions are inverse operations.""" + rng = RandomGenerator(seed=42) + volt_input = rng.float32_tensor((3, 4, 5), low=1.0, high=1000.0) + + # Test round trip: volt -> sqrt_kwatt -> volt + sqrt_kwatt_converted = volt_to_sqrt_kwatt(volt_input) + volt_recovered = sqrt_kwatt_to_volt(sqrt_kwatt_converted) + torch.testing.assert_close(volt_input, volt_recovered) + + # Test round trip: sqrt_kwatt -> volt -> sqrt_kwatt + sqrt_kwatt_input = rng.float32_tensor((3, 4, 5), low=0.001, high=1.0) + volt_converted = sqrt_kwatt_to_volt(sqrt_kwatt_input) + sqrt_kwatt_recovered = volt_to_sqrt_kwatt(volt_converted) + torch.testing.assert_close(sqrt_kwatt_input, sqrt_kwatt_recovered) + + +def test_volt_to_sqrt_kwatt_scalar(): + """Verify Volt to sqrt(kW) conversion for scalar values.""" + volt = 100.0 + expected = volt / (50e3**0.5) + result = volt_to_sqrt_kwatt(volt) + assert abs(result - expected) < 1e-6 + + +def test_sqrt_kwatt_to_volt_scalar(): + """Verify sqrt(kW) to Volt conversion for scalar values.""" + sqrt_kwatt = 0.5 + expected = sqrt_kwatt * (50e3**0.5) + result = sqrt_kwatt_to_volt(sqrt_kwatt) + assert abs(result - expected) < 1e-6 From 9cc98857be47b1b6525b4adc3df8beafdd22abeb Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 17 Jun 2025 00:51:48 +0200 Subject: [PATCH 3/9] mypy --- src/mrpro/operators/models/PEX.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mrpro/operators/models/PEX.py b/src/mrpro/operators/models/PEX.py index d856140f7..72863914d 100644 --- a/src/mrpro/operators/models/PEX.py +++ b/src/mrpro/operators/models/PEX.py @@ -33,12 +33,12 @@ def __init__( number of transmit channels. """ super().__init__() - voltages = torch.as_tensor(voltages) * n_tx**0.5 - prep_delay = torch.as_tensor(prep_delay) - pulse_duration = torch.as_tensor(pulse_duration) - self.voltages = torch.nn.Parameter(voltages, requires_grad=voltages.requires_grad) - self.prep_delay = torch.nn.Parameter(prep_delay, requires_grad=prep_delay.requires_grad) - self.pulse_duration = torch.nn.Parameter(pulse_duration, requires_grad=pulse_duration.requires_grad) + voltages_ = torch.as_tensor(voltages) * n_tx**0.5 + prep_delay_ = torch.as_tensor(prep_delay) + pulse_duration_ = torch.as_tensor(pulse_duration) + self.voltages = torch.nn.Parameter(voltages_, requires_grad=voltages_.requires_grad) + self.prep_delay = torch.nn.Parameter(prep_delay_, requires_grad=prep_delay_.requires_grad) + self.pulse_duration = torch.nn.Parameter(pulse_duration_, requires_grad=pulse_duration_.requires_grad) def forward(self, a: torch.tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: """Apply PEX signal model. From dbf3d4fff2f9ee32763174d63ae489ad3297f342 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 17 Jun 2025 00:51:48 +0200 Subject: [PATCH 4/9] mypy --- src/mrpro/operators/models/PEX.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/models/PEX.py b/src/mrpro/operators/models/PEX.py index 72863914d..3ebabea48 100644 --- a/src/mrpro/operators/models/PEX.py +++ b/src/mrpro/operators/models/PEX.py @@ -40,7 +40,7 @@ def __init__( self.prep_delay = torch.nn.Parameter(prep_delay_, requires_grad=prep_delay_.requires_grad) self.pulse_duration = torch.nn.Parameter(pulse_duration_, requires_grad=pulse_duration_.requires_grad) - def forward(self, a: torch.tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: + def __call__(self, a: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: """Apply PEX signal model. Parameters @@ -51,11 +51,19 @@ def forward(self, a: torch.tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: t1 Longitudinal relaxation time. - Returns ------- signal with shape `(voltage, *other, coils, z, y, x)` """ + return super().__call__(a, t1) + + def forward(self, a: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: + """Apply PEX signal model. + + .. note:: + Prefer calling the instance of the PEX operator as ``operator(a, t1)`` over + directly calling this method. + """ ndim = a.ndim voltages = unsqueeze_right(self.voltages, ndim - self.voltages.ndim + 1) # +1 are voltages prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim) From 94d3a28fe64a48de1776e6a5ab82bdacce9c22bf Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 17 Jun 2025 01:00:55 +0200 Subject: [PATCH 5/9] update --- tests/operators/models/test_pex.py | 164 +++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 tests/operators/models/test_pex.py diff --git a/tests/operators/models/test_pex.py b/tests/operators/models/test_pex.py new file mode 100644 index 000000000..27e9832ec --- /dev/null +++ b/tests/operators/models/test_pex.py @@ -0,0 +1,164 @@ +"""Tests for PEX signal model.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.operators.models import PEX +from mrpro.utils import RandomGenerator +from tests import autodiff_test +from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS + + +@pytest.mark.parametrize( + ('voltages', 'pulse_duration', 'expected_behavior'), + [ + (0, 1.0, 'unity'), # zero voltage should give signal close to 1 + ([0, 100, 1000], 0, 'unity'), # zero pulse duration should give signal close to 1 + ], +) +def test_pex_special_values( + voltages: float | list[float], + pulse_duration: float, + expected_behavior: str, + parameter_shape: Sequence[int] = (2, 5, 10, 10, 10), +) -> None: + """Test PEX signal at special input values.""" + rng = RandomGenerator(0) + prep_delay = 0.01 # short prep delay + + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + a = rng.float32_tensor(parameter_shape, low=0.1, high=10) # µT/sqrt(kW) + t1 = rng.float32_tensor(parameter_shape, low=0.1, high=2) + + (signal,) = model(a, t1) + + # For zero voltage or zero pulse duration, signal should be close to 1 + if expected_behavior == 'unity': + expected = torch.ones_like(signal) + torch.testing.assert_close(signal, expected, atol=1e-3, rtol=1e-3) + + +def test_pex_flip_angle_behavior(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: + """Test PEX signal behavior with increasing flip angles.""" + rng = RandomGenerator(1) + voltages = [10, 50, 100] + prep_delay = 0.01 + pulse_duration = 0.001 + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + a = rng.float32_tensor(parameter_shape, low=1, high=5) # µT/sqrt(kW) + t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) + + (signal,) = model(a, t1) + + # Signal should decrease with increasing voltage (higher flip angles) + assert torch.all(signal[0] >= signal[1]) # first voltage < second voltage + assert torch.all(signal[1] >= signal[2]) # second voltage < third voltage + assert signal.isfinite().all() + + +def test_pex_t1_recovery(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: + """Test PEX signal T1 recovery behavior.""" + rng = RandomGenerator(2) + voltages = 100 # fixed voltage + prep_delay = torch.tensor([0.001, 0.01, 0.1, 1.0]) # increasing prep delays + pulse_duration = 0.001 + + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + a = rng.float32_tensor(parameter_shape, low=1, high=5) + t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) + + (signal,) = model(a, t1) + + # Signal should increase with longer prep delay (more T1 recovery) + for i in range(len(prep_delay) - 1): + assert torch.all(signal[i] <= signal[i + 1]) + assert signal.isfinite().all() + + +@SHAPE_VARIATIONS_SIGNAL_MODELS +def test_pex_shape( + parameter_shape: Sequence[int], contrast_dim_shape: Sequence[int], signal_shape: Sequence[int] +) -> None: + """Test correct signal shapes.""" + rng = RandomGenerator(1) + voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) + prep_delay = 0.01 + pulse_duration = 0.001 + + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + a = rng.float32_tensor(parameter_shape, low=0.1, high=10) + t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) + (signal,) = model(a, t1) + assert signal.shape == signal_shape + assert signal.isfinite().all() + + +def test_autodiff_pex( + parameter_shape: Sequence[int] = (2, 5, 10), + contrast_dim_shape: Sequence[int] = (13, 2, 5, 10), +) -> None: + """Test autodiff works for PEX model.""" + rng = RandomGenerator(2) + voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) + prep_delay = 0.01 + pulse_duration = 0.001 + + model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) + a = rng.float32_tensor(parameter_shape, low=0.1, high=10) + t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) + autodiff_test(model, a, t1) + + +@pytest.mark.cuda +def test_pex_cuda(parameter_shape: Sequence[int] = (2, 5), contrast_dim_shape: Sequence[int] = (13, 2, 5)) -> None: + """Test the PEX model works on cuda devices.""" + rng = RandomGenerator(3) + voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) + prep_delay = 0.01 + pulse_duration = 0.001 + a = rng.float32_tensor(parameter_shape, low=0.1, high=10) + t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) + + # Create on CPU, transfer to GPU and run on GPU + model = PEX(voltages=voltages.tolist(), prep_delay=prep_delay, pulse_duration=pulse_duration) + model.cuda() + (signal,) = model(a.cuda(), t1.cuda()) + assert signal.is_cuda + assert signal.isfinite().all() + + # Create on GPU and run on GPU + model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration) + (signal,) = model(a.cuda(), t1.cuda()) + assert signal.is_cuda + assert signal.isfinite().all() + + # Create on GPU, transfer to CPU and run on CPU + model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration) + model.cpu() + (signal,) = model(a, t1) + assert signal.is_cpu + assert signal.isfinite().all() + + +def test_pex_n_tx_scaling(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: + """Test PEX signal scales correctly with number of transmit channels.""" + rng = RandomGenerator(4) + voltages = 100 + prep_delay = 0.01 + pulse_duration = 0.001 + a = rng.float32_tensor(parameter_shape, low=1, high=5) + t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) + + # Test with different n_tx values + model_1tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=1) + model_4tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=4) + + (signal_1tx,) = model_1tx(a, t1) + (signal_4tx,) = model_4tx(a, t1) + + # With higher n_tx, the effective voltage is scaled by sqrt(n_tx), so flip angle increases + # This should result in lower signal values + assert torch.all(signal_4tx <= signal_1tx) + assert signal_1tx.isfinite().all() + assert signal_4tx.isfinite().all() From 49d13a998599d2fd9f2c62170d2be4905e61cfa5 Mon Sep 17 00:00:00 2001 From: mxlutz Date: Thu, 31 Jul 2025 12:56:52 +0200 Subject: [PATCH 6/9] renamed a to b1 --- src/mrpro/operators/models/PEX.py | 14 +++++++------- tests/operators/models/test_pex.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/mrpro/operators/models/PEX.py b/src/mrpro/operators/models/PEX.py index 3ebabea48..b821b77e9 100644 --- a/src/mrpro/operators/models/PEX.py +++ b/src/mrpro/operators/models/PEX.py @@ -40,13 +40,13 @@ def __init__( self.prep_delay = torch.nn.Parameter(prep_delay_, requires_grad=prep_delay_.requires_grad) self.pulse_duration = torch.nn.Parameter(pulse_duration_, requires_grad=pulse_duration_.requires_grad) - def __call__(self, a: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: + def __call__(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: """Apply PEX signal model. Parameters ---------- - a - Parameter a in µT/sqrt(kW) translating voltage of the coil to flip angle + b1 + B1+ in µT/sqrt(kW) translating voltage of the coil to flip angle with shape `(...)`. t1 Longitudinal relaxation time. @@ -55,16 +55,16 @@ def __call__(self, a: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: ------- signal with shape `(voltage, *other, coils, z, y, x)` """ - return super().__call__(a, t1) + return super().__call__(b1, t1) - def forward(self, a: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: + def forward(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: """Apply PEX signal model. .. note:: Prefer calling the instance of the PEX operator as ``operator(a, t1)`` over directly calling this method. """ - ndim = a.ndim + ndim = b1.ndim voltages = unsqueeze_right(self.voltages, ndim - self.voltages.ndim + 1) # +1 are voltages prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim) t1 = unsqueeze_right(t1, ndim - t1.ndim) @@ -74,7 +74,7 @@ def forward(self, a: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: signal = 1 - ( 1 - torch.cos( - a * volt_to_sqrt_kwatt(voltages) * 1e-6 * pulse_duration * GYROMAGNETIC_RATIO_PROTON * 2 * torch.pi + b1 * volt_to_sqrt_kwatt(voltages) * 1e-6 * pulse_duration * GYROMAGNETIC_RATIO_PROTON * 2 * torch.pi ) ) * torch.exp(-prep_delay / t1) return (signal,) diff --git a/tests/operators/models/test_pex.py b/tests/operators/models/test_pex.py index 27e9832ec..e54ab1a4c 100644 --- a/tests/operators/models/test_pex.py +++ b/tests/operators/models/test_pex.py @@ -28,10 +28,10 @@ def test_pex_special_values( prep_delay = 0.01 # short prep delay model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) - a = rng.float32_tensor(parameter_shape, low=0.1, high=10) # µT/sqrt(kW) + b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) # µT/sqrt(kW) t1 = rng.float32_tensor(parameter_shape, low=0.1, high=2) - (signal,) = model(a, t1) + (signal,) = model(b1, t1) # For zero voltage or zero pulse duration, signal should be close to 1 if expected_behavior == 'unity': From d526c946343ca07715a7af42795def019af0ada1 Mon Sep 17 00:00:00 2001 From: mxlutz Date: Thu, 31 Jul 2025 13:59:51 +0200 Subject: [PATCH 7/9] changed prep_delay dim --- src/mrpro/operators/models/PEX.py | 3 +-- tests/operators/models/test_pex.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/mrpro/operators/models/PEX.py b/src/mrpro/operators/models/PEX.py index b821b77e9..5917a4ebc 100644 --- a/src/mrpro/operators/models/PEX.py +++ b/src/mrpro/operators/models/PEX.py @@ -66,8 +66,7 @@ def forward(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: """ ndim = b1.ndim voltages = unsqueeze_right(self.voltages, ndim - self.voltages.ndim + 1) # +1 are voltages - prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim) - t1 = unsqueeze_right(t1, ndim - t1.ndim) + prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim + 1) pulse_duration = unsqueeze_right(self.pulse_duration, ndim - self.pulse_duration.ndim) # this is mainly cos(FA), where FA = gamma * a * voltage * t diff --git a/tests/operators/models/test_pex.py b/tests/operators/models/test_pex.py index e54ab1a4c..3dd994999 100644 --- a/tests/operators/models/test_pex.py +++ b/tests/operators/models/test_pex.py @@ -46,10 +46,10 @@ def test_pex_flip_angle_behavior(parameter_shape: Sequence[int] = (2, 5, 10)) -> prep_delay = 0.01 pulse_duration = 0.001 model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) - a = rng.float32_tensor(parameter_shape, low=1, high=5) # µT/sqrt(kW) + b1 = rng.float32_tensor(parameter_shape, low=1, high=5) # µT/sqrt(kW) t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) - (signal,) = model(a, t1) + (signal,) = model(b1, t1) # Signal should decrease with increasing voltage (higher flip angles) assert torch.all(signal[0] >= signal[1]) # first voltage < second voltage @@ -65,10 +65,10 @@ def test_pex_t1_recovery(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: pulse_duration = 0.001 model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) - a = rng.float32_tensor(parameter_shape, low=1, high=5) + b1 = rng.float32_tensor(parameter_shape, low=1, high=5) t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) - (signal,) = model(a, t1) + (signal,) = model(b1, t1) # Signal should increase with longer prep delay (more T1 recovery) for i in range(len(prep_delay) - 1): @@ -87,9 +87,9 @@ def test_pex_shape( pulse_duration = 0.001 model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) - a = rng.float32_tensor(parameter_shape, low=0.1, high=10) + b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) - (signal,) = model(a, t1) + (signal,) = model(b1, t1) assert signal.shape == signal_shape assert signal.isfinite().all() @@ -105,9 +105,9 @@ def test_autodiff_pex( pulse_duration = 0.001 model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) - a = rng.float32_tensor(parameter_shape, low=0.1, high=10) + b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) - autodiff_test(model, a, t1) + autodiff_test(model, b1, t1) @pytest.mark.cuda @@ -117,26 +117,26 @@ def test_pex_cuda(parameter_shape: Sequence[int] = (2, 5), contrast_dim_shape: S voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) prep_delay = 0.01 pulse_duration = 0.001 - a = rng.float32_tensor(parameter_shape, low=0.1, high=10) + b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) # Create on CPU, transfer to GPU and run on GPU model = PEX(voltages=voltages.tolist(), prep_delay=prep_delay, pulse_duration=pulse_duration) model.cuda() - (signal,) = model(a.cuda(), t1.cuda()) + (signal,) = model(b1.cuda(), t1.cuda()) assert signal.is_cuda assert signal.isfinite().all() # Create on GPU and run on GPU model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration) - (signal,) = model(a.cuda(), t1.cuda()) + (signal,) = model(b1.cuda(), t1.cuda()) assert signal.is_cuda assert signal.isfinite().all() # Create on GPU, transfer to CPU and run on CPU model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration) model.cpu() - (signal,) = model(a, t1) + (signal,) = model(b1, t1) assert signal.is_cpu assert signal.isfinite().all() @@ -147,15 +147,15 @@ def test_pex_n_tx_scaling(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: voltages = 100 prep_delay = 0.01 pulse_duration = 0.001 - a = rng.float32_tensor(parameter_shape, low=1, high=5) + b1 = rng.float32_tensor(parameter_shape, low=1, high=5) t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) # Test with different n_tx values model_1tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=1) model_4tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=4) - (signal_1tx,) = model_1tx(a, t1) - (signal_4tx,) = model_4tx(a, t1) + (signal_1tx,) = model_1tx(b1, t1) + (signal_4tx,) = model_4tx(b1, t1) # With higher n_tx, the effective voltage is scaled by sqrt(n_tx), so flip angle increases # This should result in lower signal values From 97491bd4041328914579e0b5d381a3a6dfd1651b Mon Sep 17 00:00:00 2001 From: mxlutz Date: Thu, 31 Jul 2025 14:13:30 +0200 Subject: [PATCH 8/9] little typos --- src/mrpro/operators/models/PEX.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrpro/operators/models/PEX.py b/src/mrpro/operators/models/PEX.py index 5917a4ebc..e721a7baa 100644 --- a/src/mrpro/operators/models/PEX.py +++ b/src/mrpro/operators/models/PEX.py @@ -61,7 +61,7 @@ def forward(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: """Apply PEX signal model. .. note:: - Prefer calling the instance of the PEX operator as ``operator(a, t1)`` over + Prefer calling the instance of the PEX operator as ``operator(b1, t1)`` over directly calling this method. """ ndim = b1.ndim @@ -69,7 +69,7 @@ def forward(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim + 1) pulse_duration = unsqueeze_right(self.pulse_duration, ndim - self.pulse_duration.ndim) - # this is mainly cos(FA), where FA = gamma * a * voltage * t + # this is mainly cos(FA), where FA = gamma * b1 * voltage * t signal = 1 - ( 1 - torch.cos( From 0551fb299eecf004787294c6f3020de621f20385 Mon Sep 17 00:00:00 2001 From: mxlutz Date: Thu, 31 Jul 2025 14:16:42 +0200 Subject: [PATCH 9/9] added prep delay in 0th dim --- src/mrpro/operators/models/PEX.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mrpro/operators/models/PEX.py b/src/mrpro/operators/models/PEX.py index e721a7baa..2ef50db67 100644 --- a/src/mrpro/operators/models/PEX.py +++ b/src/mrpro/operators/models/PEX.py @@ -24,9 +24,9 @@ def __init__( Parameters ---------- voltages - voltages. Shape `(Voltages, ...)`. + voltages. Shape `(voltage, ...)`. prep_delay - preparation delay. Shape `(...)`. + preparation delay. Shape `(prepdelay, ...)`. pulse_duration rect pulse duration in seconds. Shape `(...)`. n_tx @@ -53,7 +53,7 @@ def __call__(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: Returns ------- - signal with shape `(voltage, *other, coils, z, y, x)` + signal with shape `(voltage/prepdelay, *other, coils, z, y, x)` """ return super().__call__(b1, t1)