-
Notifications
You must be signed in to change notification settings - Fork 11
Add PEX signal model #841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
fzimmermann89
wants to merge
12
commits into
main
Choose a base branch
from
242-pex-signal-model
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Add PEX signal model #841
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
1817365
first version, needs clean up, fit function simple for slow pex
mxlutz cccdc26
Merge branch 'main' into 242-pex-signal-model
fzimmermann89 62a4701
add tests. clean up
fzimmermann89 974c211
Merge branch 'main' into 242-pex-signal-model
fzimmermann89 9cc9885
mypy
fzimmermann89 dbf3d4f
mypy
fzimmermann89 94d3a28
update
fzimmermann89 f3f2132
Merge branch 'main' into 242-pex-signal-model
mxlutz 49d13a9
renamed a to b1
mxlutz d526c94
changed prep_delay dim
mxlutz 97491bd
little typos
mxlutz 0551fb2
added prep delay in 0th dim
mxlutz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| """Saturation recovery signal model for T1 mapping.""" | ||
|
|
||
| from collections.abc import Sequence | ||
|
|
||
| import torch | ||
|
|
||
| from mrpro.operators.SignalModel import SignalModel | ||
| from mrpro.utils import unsqueeze_right | ||
| from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON, volt_to_sqrt_kwatt | ||
|
|
||
|
|
||
| class PEX(SignalModel[torch.Tensor, torch.Tensor]): | ||
| """Signal model for preparation based B1+ mapping (PEX).""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| voltages: float | torch.Tensor | Sequence[float], | ||
| prep_delay: float | torch.Tensor, | ||
| pulse_duration: float | torch.Tensor, | ||
| n_tx: int = 1, | ||
| ) -> None: | ||
| """Initialize preparation based B1+ mapping (PEX) signal model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| voltages | ||
| voltages. Shape `(voltage, ...)`. | ||
| prep_delay | ||
| preparation delay. Shape `(prepdelay, ...)`. | ||
| pulse_duration | ||
| rect pulse duration in seconds. Shape `(...)`. | ||
| n_tx | ||
| number of transmit channels. | ||
| """ | ||
| super().__init__() | ||
| voltages_ = torch.as_tensor(voltages) * n_tx**0.5 | ||
| prep_delay_ = torch.as_tensor(prep_delay) | ||
| pulse_duration_ = torch.as_tensor(pulse_duration) | ||
| self.voltages = torch.nn.Parameter(voltages_, requires_grad=voltages_.requires_grad) | ||
| self.prep_delay = torch.nn.Parameter(prep_delay_, requires_grad=prep_delay_.requires_grad) | ||
| self.pulse_duration = torch.nn.Parameter(pulse_duration_, requires_grad=pulse_duration_.requires_grad) | ||
|
|
||
| def __call__(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: | ||
| """Apply PEX signal model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| b1 | ||
| B1+ in µT/sqrt(kW) translating voltage of the coil to flip angle | ||
| with shape `(...)`. | ||
| t1 | ||
| Longitudinal relaxation time. | ||
|
|
||
| Returns | ||
| ------- | ||
| signal with shape `(voltage/prepdelay, *other, coils, z, y, x)` | ||
| """ | ||
| return super().__call__(b1, t1) | ||
|
|
||
| def forward(self, b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: | ||
| """Apply PEX signal model. | ||
|
|
||
| .. note:: | ||
| Prefer calling the instance of the PEX operator as ``operator(b1, t1)`` over | ||
| directly calling this method. | ||
| """ | ||
| ndim = b1.ndim | ||
| voltages = unsqueeze_right(self.voltages, ndim - self.voltages.ndim + 1) # +1 are voltages | ||
| prep_delay = unsqueeze_right(self.prep_delay, ndim - self.prep_delay.ndim + 1) | ||
| pulse_duration = unsqueeze_right(self.pulse_duration, ndim - self.pulse_duration.ndim) | ||
|
|
||
| # this is mainly cos(FA), where FA = gamma * b1 * voltage * t | ||
| signal = 1 - ( | ||
| 1 | ||
| - torch.cos( | ||
| b1 * volt_to_sqrt_kwatt(voltages) * 1e-6 * pulse_duration * GYROMAGNETIC_RATIO_PROTON * 2 * torch.pi | ||
| ) | ||
| ) * torch.exp(-prep_delay / t1) | ||
| return (signal,) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| """Tests for PEX signal model.""" | ||
|
|
||
| from collections.abc import Sequence | ||
|
|
||
| import pytest | ||
| import torch | ||
| from mrpro.operators.models import PEX | ||
| from mrpro.utils import RandomGenerator | ||
| from tests import autodiff_test | ||
| from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ('voltages', 'pulse_duration', 'expected_behavior'), | ||
| [ | ||
| (0, 1.0, 'unity'), # zero voltage should give signal close to 1 | ||
| ([0, 100, 1000], 0, 'unity'), # zero pulse duration should give signal close to 1 | ||
| ], | ||
| ) | ||
| def test_pex_special_values( | ||
| voltages: float | list[float], | ||
| pulse_duration: float, | ||
| expected_behavior: str, | ||
| parameter_shape: Sequence[int] = (2, 5, 10, 10, 10), | ||
| ) -> None: | ||
| """Test PEX signal at special input values.""" | ||
| rng = RandomGenerator(0) | ||
| prep_delay = 0.01 # short prep delay | ||
|
|
||
| model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) | ||
| b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) # µT/sqrt(kW) | ||
| t1 = rng.float32_tensor(parameter_shape, low=0.1, high=2) | ||
|
|
||
| (signal,) = model(b1, t1) | ||
|
|
||
| # For zero voltage or zero pulse duration, signal should be close to 1 | ||
| if expected_behavior == 'unity': | ||
| expected = torch.ones_like(signal) | ||
| torch.testing.assert_close(signal, expected, atol=1e-3, rtol=1e-3) | ||
|
|
||
|
|
||
| def test_pex_flip_angle_behavior(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: | ||
| """Test PEX signal behavior with increasing flip angles.""" | ||
| rng = RandomGenerator(1) | ||
| voltages = [10, 50, 100] | ||
| prep_delay = 0.01 | ||
| pulse_duration = 0.001 | ||
| model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) | ||
| b1 = rng.float32_tensor(parameter_shape, low=1, high=5) # µT/sqrt(kW) | ||
| t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) | ||
|
|
||
| (signal,) = model(b1, t1) | ||
|
|
||
| # Signal should decrease with increasing voltage (higher flip angles) | ||
| assert torch.all(signal[0] >= signal[1]) # first voltage < second voltage | ||
| assert torch.all(signal[1] >= signal[2]) # second voltage < third voltage | ||
| assert signal.isfinite().all() | ||
|
|
||
|
|
||
| def test_pex_t1_recovery(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: | ||
| """Test PEX signal T1 recovery behavior.""" | ||
| rng = RandomGenerator(2) | ||
| voltages = 100 # fixed voltage | ||
| prep_delay = torch.tensor([0.001, 0.01, 0.1, 1.0]) # increasing prep delays | ||
| pulse_duration = 0.001 | ||
|
|
||
| model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) | ||
| b1 = rng.float32_tensor(parameter_shape, low=1, high=5) | ||
| t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) | ||
|
|
||
| (signal,) = model(b1, t1) | ||
|
|
||
| # Signal should increase with longer prep delay (more T1 recovery) | ||
| for i in range(len(prep_delay) - 1): | ||
| assert torch.all(signal[i] <= signal[i + 1]) | ||
| assert signal.isfinite().all() | ||
|
|
||
|
|
||
| @SHAPE_VARIATIONS_SIGNAL_MODELS | ||
| def test_pex_shape( | ||
| parameter_shape: Sequence[int], contrast_dim_shape: Sequence[int], signal_shape: Sequence[int] | ||
| ) -> None: | ||
| """Test correct signal shapes.""" | ||
| rng = RandomGenerator(1) | ||
| voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) | ||
| prep_delay = 0.01 | ||
| pulse_duration = 0.001 | ||
|
|
||
| model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) | ||
| b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) | ||
| t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) | ||
| (signal,) = model(b1, t1) | ||
| assert signal.shape == signal_shape | ||
| assert signal.isfinite().all() | ||
|
|
||
|
|
||
| def test_autodiff_pex( | ||
| parameter_shape: Sequence[int] = (2, 5, 10), | ||
| contrast_dim_shape: Sequence[int] = (13, 2, 5, 10), | ||
| ) -> None: | ||
| """Test autodiff works for PEX model.""" | ||
| rng = RandomGenerator(2) | ||
| voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) | ||
| prep_delay = 0.01 | ||
| pulse_duration = 0.001 | ||
|
|
||
| model = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration) | ||
| b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) | ||
| t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) | ||
| autodiff_test(model, b1, t1) | ||
|
|
||
|
|
||
| @pytest.mark.cuda | ||
| def test_pex_cuda(parameter_shape: Sequence[int] = (2, 5), contrast_dim_shape: Sequence[int] = (13, 2, 5)) -> None: | ||
| """Test the PEX model works on cuda devices.""" | ||
| rng = RandomGenerator(3) | ||
| voltages = rng.float32_tensor(contrast_dim_shape, low=0, high=200) | ||
| prep_delay = 0.01 | ||
| pulse_duration = 0.001 | ||
| b1 = rng.float32_tensor(parameter_shape, low=0.1, high=10) | ||
| t1 = rng.float32_tensor(parameter_shape, low=0.01, high=2) | ||
|
|
||
| # Create on CPU, transfer to GPU and run on GPU | ||
| model = PEX(voltages=voltages.tolist(), prep_delay=prep_delay, pulse_duration=pulse_duration) | ||
| model.cuda() | ||
| (signal,) = model(b1.cuda(), t1.cuda()) | ||
| assert signal.is_cuda | ||
| assert signal.isfinite().all() | ||
|
|
||
| # Create on GPU and run on GPU | ||
| model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration) | ||
| (signal,) = model(b1.cuda(), t1.cuda()) | ||
| assert signal.is_cuda | ||
| assert signal.isfinite().all() | ||
|
|
||
| # Create on GPU, transfer to CPU and run on CPU | ||
| model = PEX(voltages=voltages.cuda(), prep_delay=prep_delay, pulse_duration=pulse_duration) | ||
| model.cpu() | ||
| (signal,) = model(b1, t1) | ||
| assert signal.is_cpu | ||
| assert signal.isfinite().all() | ||
|
|
||
|
|
||
| def test_pex_n_tx_scaling(parameter_shape: Sequence[int] = (2, 5, 10)) -> None: | ||
| """Test PEX signal scales correctly with number of transmit channels.""" | ||
| rng = RandomGenerator(4) | ||
| voltages = 100 | ||
| prep_delay = 0.01 | ||
| pulse_duration = 0.001 | ||
| b1 = rng.float32_tensor(parameter_shape, low=1, high=5) | ||
| t1 = rng.float32_tensor(parameter_shape, low=0.5, high=2) | ||
|
|
||
| # Test with different n_tx values | ||
| model_1tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=1) | ||
| model_4tx = PEX(voltages=voltages, prep_delay=prep_delay, pulse_duration=pulse_duration, n_tx=4) | ||
|
|
||
| (signal_1tx,) = model_1tx(b1, t1) | ||
| (signal_4tx,) = model_4tx(b1, t1) | ||
|
|
||
| # With higher n_tx, the effective voltage is scaled by sqrt(n_tx), so flip angle increases | ||
| # This should result in lower signal values | ||
| assert torch.all(signal_4tx <= signal_1tx) | ||
| assert signal_1tx.isfinite().all() | ||
| assert signal_4tx.isfinite().all() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For me the test_pex_t1_recovery Test was failing due to dimension issues in the signal function (prep_delay and t1 are arrays). This is fixed with moving this to the same dimension as voltages, however I am not sure whether this is a good idea? Usually one of them should only be scalar, could add some assert statement.