From d7ff10871b3776ae330e12a1c10146a04d868a00 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Sun, 3 May 2026 20:05:31 +0200 Subject: [PATCH] Add BMC model and tests --- src/mrpro/operators/models/BMC.py | 1127 ++++++++++++++++++++++++ src/mrpro/operators/models/__init__.py | 4 +- src/mrpro/utils/slice_profiles.py | 114 ++- tests/operators/models/test_bmc.py | 1009 +++++++++++++++++++++ 4 files changed, 2252 insertions(+), 2 deletions(-) create mode 100644 src/mrpro/operators/models/BMC.py create mode 100644 tests/operators/models/test_bmc.py diff --git a/src/mrpro/operators/models/BMC.py b/src/mrpro/operators/models/BMC.py new file mode 100644 index 000000000..03b9b28b0 --- /dev/null +++ b/src/mrpro/operators/models/BMC.py @@ -0,0 +1,1127 @@ +"""Bloch-McConnell simulation.""" + +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass + +import torch +from torch.utils.checkpoint import checkpoint as activation_checkpoint + +from mrpro.data.Dataclass import Dataclass +from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.utils.reshape import unsqueeze_left, unsqueeze_right +from mrpro.utils.slice_profiles import SliceProfileBase +from mrpro.utils.TensorAttributeMixin import TensorAttributeMixin +from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON + + +@dataclass +class MTSaturation(ABC, TensorAttributeMixin): + """Base class for MT lineshape models.""" + + pool_index: int + """Index of the MT pool in the pool dimension.""" + + t2: torch.Tensor + """Transverse relaxation time in seconds.""" + + @abstractmethod + def __call__(self, delta_omega: torch.Tensor) -> torch.Tensor: + r"""Evaluate \(G(\Delta)\) [s].""" + + +@dataclass +class LorentzianMT(MTSaturation): + """Lorentzian lineshape for MT saturation.""" + + def __call__(self, delta_omega: torch.Tensor) -> torch.Tensor: + r"""Evaluate \(G(\Delta)\) [s]. + + Parameters + ---------- + delta_omega + Detuning in rad/s. + + Returns + ------- + g + Lineshape value in seconds. + """ + t2 = self.t2.to(delta_omega) + x = delta_omega * t2 + return t2 / (1 + x * x) + + +@dataclass +class SuperLorentzianMT(MTSaturation): + """Super-Lorentzian lineshape for MT saturation.""" + + samples: int = 101 + """Quadrature samples for numerical integration.""" + + def __call__(self, delta_omega: torch.Tensor) -> torch.Tensor: + r"""Evaluate \(G(\Delta)\) [s]. + + Parameters + ---------- + delta_omega + Detuning in rad/s. + t2 + Transverse relaxation time in seconds. + """ + t2 = self.t2.to(delta_omega) + u = torch.linspace(0.0, 1.0, self.samples, device=delta_omega.device, dtype=delta_omega.dtype) + du = u[1] - u[0] + denom = (3 * u * u - 1).abs().clamp_min(1e-12) + x = (delta_omega[..., None] * t2[..., None]) / denom + integrand = (2.0 / torch.pi) ** 0.5 * t2[..., None] / denom * torch.exp(-2 * x * x) + return integrand.sum(dim=-1) * torch.pi * du + + +class Parameters(Dataclass): + """Parameters for Bloch-McConnell simulation. + + Shapes + ------ + - poolwise: ``(..., pools)`` + + Notes + ----- + Hyperpolarization is handled by setting a non-equilibrium initial state + (via ``initial_state(..., mz=...)`` or ``ResetBlock(state=...)``) and by + choosing ``equilibrium_magnetization`` appropriately. + """ + + equilibrium_magnetization: torch.Tensor + """Equilibrium magnetization.""" + t1: torch.Tensor + """T1 relaxation time in seconds. + Shape ``(..., pools)``. + """ + t2: torch.Tensor + """T2 relaxation time in seconds. Shape ``(..., pools)``.""" + + exchange_rate: torch.Tensor + """Exchange rate in 1/s. + Shape ``(..., pools, pools)`` + where element ``[..., i, j]`` is the ratefrom pool j to pool i.""" + + chemical_shift: torch.Tensor | None = None + """Chemical shift in Hz. Shape ``(..., pools)``.""" + + static_off_resonance: torch.Tensor | None = None + """Delta B0 in rad/s. Shape ``(...)`` (global per voxel/batch).""" + + relative_b1: torch.Tensor | None = None + """Relative B1 scaling factor. Shape ``(...)`` (global per voxel/batch).""" + + mt_saturation: MTSaturation | None = None + """MT saturation model. Shape ``(..., pools)``. """ + + @property + def n_pools(self) -> int: + """Number of pools.""" + return int(self.equilibrium_magnetization.shape[-1]) + + @property + def ndim(self) -> int: + """Broadcast ndim of parameter batch dimensions.""" + ndim = max( + self.equilibrium_magnetization.ndim, + self.t1.ndim, + self.t2.ndim, + self.exchange_rate.ndim - 1, + ) + if self.chemical_shift is not None: + ndim = max(ndim, self.chemical_shift.ndim) + if self.static_off_resonance is not None: + ndim = max(ndim, self.static_off_resonance.ndim + 1) + if self.relative_b1 is not None: + ndim = max(ndim, self.relative_b1.ndim + 1) + return ndim + + +def system_recovery_vector(parameters: Parameters) -> torch.Tensor: + """Build the affine recovery vector.""" + m0, t1 = parameters.equilibrium_magnetization, parameters.t1 + batch_shape = torch.broadcast_shapes( + m0.shape[:-1], + t1.shape[:-1], + parameters.t2.shape[:-1], + parameters.exchange_rate.shape[:-2], + ) + if parameters.chemical_shift is not None: + batch_shape = torch.broadcast_shapes(batch_shape, parameters.chemical_shift.shape[:-1]) + if parameters.static_off_resonance is not None: + batch_shape = torch.broadcast_shapes(batch_shape, parameters.static_off_resonance.shape) + if parameters.relative_b1 is not None: + batch_shape = torch.broadcast_shapes(batch_shape, parameters.relative_b1.shape) + c = torch.zeros(*batch_shape, 3 * parameters.n_pools, device=m0.device, dtype=m0.dtype) + c[..., 2 * parameters.n_pools :] = (1.0 / t1) * m0 + return c + + +def initial_state(parameters: Parameters, mz: torch.Tensor | None = None) -> torch.Tensor: + """Create an initial magnetization state. + + Parameters + ---------- + parameters + Simulation parameters. + mz + Optional initial longitudinal magnetization, shape ``(..., pools)``. + If omitted, uses equilibrium_magnetization. + + Returns + ------- + state + Tensor with shape ``(..., isochromats=1, pools, 3)`` holding ``(Mx, My, Mz)``. + """ + mz0 = parameters.equilibrium_magnetization if mz is None else mz + mz0 = mz0.to(parameters.equilibrium_magnetization) + z = torch.zeros_like(mz0) + return torch.stack([z, z, mz0], dim=-1).unsqueeze(-3) + + +def exchange_generator(exchange_rate: torch.Tensor) -> torch.Tensor: + r"""Construct exchange generator \(Q\) for \(dM/dt = Q M\). + + Parameters + ---------- + exchange_rate + Shape ``(..., pools, pools)`` with element ``[..., i, j]`` the rate + from pool j to pool i. + + Returns + ------- + q + Shape ``(..., pools, pools)``. + """ + exchange_rate = torch.as_tensor(exchange_rate) + out_rate = exchange_rate.sum(dim=-2) + return exchange_rate - torch.diag_embed(out_rate) + + +def system_base_matrix( + parameters: Parameters, + rf_frequency: torch.Tensor | float, + extra_off_resonance: torch.Tensor | float = 0.0, +) -> torch.Tensor: + """Build the RF-amplitude independent Bloch-McConnell matrix.""" + m0, t1, t2, exchange = parameters.equilibrium_magnetization, parameters.t1, parameters.t2, parameters.exchange_rate + freq = torch.as_tensor(rf_frequency, device=m0.device, dtype=m0.dtype) + extra_dw = torch.as_tensor(extra_off_resonance, device=m0.device, dtype=m0.dtype) + + if parameters.chemical_shift is not None: + shift = parameters.chemical_shift.to(m0) + else: + shift = m0.new_zeros(*m0.shape[:-1], parameters.n_pools) + + if parameters.static_off_resonance is not None: + dw0 = parameters.static_off_resonance.to(m0) + else: + dw0 = m0.new_zeros(m0.shape[:-1]) + + batch = torch.broadcast_shapes( + m0.shape[:-1], + t1.shape[:-1], + t2.shape[:-1], + exchange.shape[:-2], + freq.shape, + extra_dw.shape, + shift.shape[:-1], + dw0.shape, + ) + t1 = torch.broadcast_to(t1, (*batch, parameters.n_pools)) + t2 = torch.broadcast_to(t2, (*batch, parameters.n_pools)) + exchange = torch.broadcast_to(exchange, (*batch, parameters.n_pools, parameters.n_pools)) + shift = torch.broadcast_to(shift, (*batch, parameters.n_pools)) + dw0 = torch.broadcast_to(dw0, batch) + freq = torch.broadcast_to(freq, batch) + extra_dw = torch.broadcast_to(extra_dw, batch) + + r1 = 1.0 / t1 + r2 = 1.0 / t2 + + qz = exchange_generator(exchange) + qxy = qz + if parameters.mt_saturation is not None: + if not (0 <= parameters.mt_saturation.pool_index < parameters.n_pools): + raise ValueError('mt_saturation.pool_index out of bounds.') + qxy = qz.clone() + qxy[..., parameters.mt_saturation.pool_index, :] = 0 + qxy[..., :, parameters.mt_saturation.pool_index] = 0 + + delta_omega = dw0[..., None] + extra_dw[..., None] - 2 * torch.pi * freq[..., None] + 2 * torch.pi * shift + + a_xx = qxy - torch.diag_embed(r2) + a_zz = qz - torch.diag_embed(r1) + a_xy = -torch.diag_embed(delta_omega) + + n = 3 * parameters.n_pools + matrix = torch.zeros(*batch, n, n, device=m0.device, dtype=m0.dtype) + matrix[..., : parameters.n_pools, : parameters.n_pools] = a_xx + matrix[..., parameters.n_pools : 2 * parameters.n_pools, parameters.n_pools : 2 * parameters.n_pools] = a_xx + matrix[..., 2 * parameters.n_pools :, 2 * parameters.n_pools :] = a_zz + matrix[..., : parameters.n_pools, parameters.n_pools : 2 * parameters.n_pools] += a_xy + matrix[..., parameters.n_pools : 2 * parameters.n_pools, : parameters.n_pools] -= a_xy + return matrix + + +def gradient_to_extra_off_resonance( + gradient_z: torch.Tensor | float | None, + gradient_y: torch.Tensor | float | None, + gradient_x: torch.Tensor | float | None, + positions: SpatialDimension[torch.Tensor], +) -> torch.Tensor: + """Convert gradients and isochromat positions to extra off-resonance in rad/s.""" + device = positions.z.device + gradient_z = torch.as_tensor(0.0 if gradient_z is None else gradient_z, device=device) + gradient_y = torch.as_tensor(0.0 if gradient_y is None else gradient_y, device=device) + gradient_x = torch.as_tensor(0.0 if gradient_x is None else gradient_x, device=device) + px = positions.x.to(device) + py = positions.y.to(device) + pz = positions.z.to(device) + return ( + 2 + * torch.pi + * GYROMAGNETIC_RATIO_PROTON + * (gradient_z[..., None] * pz + gradient_y[..., None] * py + gradient_x[..., None] * px) + ) + + +def system_rf_matrix( + parameters: Parameters, + rf_amplitude: torch.Tensor | float, + rf_phase: torch.Tensor | float, + rf_frequency: torch.Tensor | float, + extra_off_resonance: torch.Tensor | float = 0.0, +) -> torch.Tensor: + """Build the RF-amplitude dependent Bloch-McConnell matrix contribution.""" + m0 = parameters.equilibrium_magnetization + + amp = torch.as_tensor(rf_amplitude, device=m0.device, dtype=m0.dtype) + phase = torch.as_tensor(rf_phase, device=m0.device, dtype=m0.dtype) + freq = torch.as_tensor(rf_frequency, device=m0.device, dtype=m0.dtype) + extra_dw = torch.as_tensor(extra_off_resonance, device=m0.device, dtype=m0.dtype) + + if parameters.relative_b1 is not None: + rb1 = parameters.relative_b1.to(amp) + if rb1.is_complex(): + phase = phase + rb1.angle() + amp = amp * rb1.abs() + else: + amp = amp * rb1 + + if parameters.chemical_shift is not None: + shift = parameters.chemical_shift.to(m0) + else: + shift = m0.new_zeros(*m0.shape[:-1], parameters.n_pools) + + if parameters.static_off_resonance is not None: + dw0 = parameters.static_off_resonance.to(m0) + else: + dw0 = m0.new_zeros(m0.shape[:-1]) + + batch = torch.broadcast_shapes( + m0.shape[:-1], + amp.shape, + phase.shape, + freq.shape, + extra_dw.shape, + shift.shape[:-1], + dw0.shape, + ) + shift = torch.broadcast_to(shift, (*batch, parameters.n_pools)) + dw0 = torch.broadcast_to(dw0, batch) + amp = torch.broadcast_to(amp, batch) + phase = torch.broadcast_to(phase, batch) + freq = torch.broadcast_to(freq, batch) + extra_dw = torch.broadcast_to(extra_dw, batch) + + delta_omega = dw0[..., None] + extra_dw[..., None] - 2 * torch.pi * freq[..., None] + 2 * torch.pi * shift + + w1 = 2 * torch.pi * amp + w1x = w1 * torch.cos(phase) + w1y = w1 * torch.sin(phase) + + eye_rf = torch.eye(parameters.n_pools, device=m0.device, dtype=m0.dtype) + if parameters.mt_saturation is not None: + eye_rf = eye_rf.clone() + eye_rf[parameters.mt_saturation.pool_index, parameters.mt_saturation.pool_index] = 0.0 + + a_xz = -w1y[..., None, None] * eye_rf + a_yz = w1x[..., None, None] * eye_rf + + n = 3 * parameters.n_pools + matrix = torch.zeros(*batch, n, n, device=m0.device, dtype=m0.dtype) + matrix[..., : parameters.n_pools, 2 * parameters.n_pools :] += a_xz + matrix[..., parameters.n_pools : 2 * parameters.n_pools, 2 * parameters.n_pools :] += a_yz + matrix[..., 2 * parameters.n_pools :, : parameters.n_pools] -= a_xz + matrix[..., 2 * parameters.n_pools :, parameters.n_pools : 2 * parameters.n_pools] -= a_yz + + if parameters.mt_saturation is not None: + g = parameters.mt_saturation(delta_omega[..., parameters.mt_saturation.pool_index]) + one_hot = torch.zeros(parameters.n_pools, device=m0.device, dtype=m0.dtype) + one_hot[parameters.mt_saturation.pool_index] = 1.0 + mt_diag = torch.diag_embed(((w1 * w1) * g)[..., None] * one_hot) + matrix[..., 2 * parameters.n_pools :, 2 * parameters.n_pools :] -= mt_diag + + return matrix + + +def system_matrix( + parameters: Parameters, + rf_amplitude: torch.Tensor | float, + rf_phase: torch.Tensor | float, + rf_frequency: torch.Tensor | float, + extra_off_resonance: torch.Tensor | float = 0.0, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Build affine Bloch-McConnell system \(dm/dt = A m + c\). + + Parameters + ---------- + parameters + Simulation parameters. + rf_amplitude + RF amplitude in Hz, broadcastable to batch. Shape ``(...)``. + rf_phase + RF phase in rad, broadcastable to batch. Shape ``(...)``. + rf_frequency + RF carrier offset in Hz, broadcastable to batch. Shape ``(...)``. + extra_off_resonance + Additional off-resonance in rad/s, broadcastable to batch. Shape ``(...)``. + + Returns + ------- + A + System matrix with shape ``(..., 3*pools, 3*pools)``. + c + Inhomogeneity vector with shape ``(..., 3*pools)``. + """ + matrix = system_base_matrix(parameters, rf_frequency, extra_off_resonance) + system_rf_matrix( + parameters, rf_amplitude, rf_phase, rf_frequency, extra_off_resonance + ) + c = system_recovery_vector(parameters) + return matrix, c + + +def propagate( + state: torch.Tensor, + matrix: torch.Tensor, + c: torch.Tensor, + duration: torch.Tensor | float, +) -> torch.Tensor: + r"""Propagate dynamics \(dm/dt = A m + c\) via exact affine evolution.""" + matrix = matrix.unsqueeze(-3) + c = c.unsqueeze(-2) + duration = torch.as_tensor(duration, device=matrix.device, dtype=matrix.dtype) + step = propagation_step(matrix, c, duration) + return apply_propagation_step(state, step) + + +def propagation_step( + matrix: torch.Tensor, + c: torch.Tensor, + duration: torch.Tensor | float, +) -> torch.Tensor: + """Build exact affine propagation steps for constant-system evolution.""" + duration = torch.as_tensor(duration, device=matrix.device, dtype=matrix.dtype) + duration = unsqueeze_right(duration, matrix.ndim - duration.ndim) + linear_step = torch.matrix_exp(matrix * duration) + identity = torch.eye(matrix.shape[-1], device=matrix.device, dtype=matrix.dtype) + offset_rhs = ((linear_step - identity) @ c.unsqueeze(-1)).squeeze(-1) + offset, info = torch.linalg.solve_ex(matrix, offset_rhs.unsqueeze(-1)) + if torch.any(info != 0): + augmented = torch.zeros( + *matrix.shape[:-2], + matrix.shape[-1] + 1, + matrix.shape[-1] + 1, + device=matrix.device, + dtype=matrix.dtype, + ) + augmented[..., :-1, :-1] = matrix + augmented[..., :-1, -1] = c + augmented_step = torch.matrix_exp(augmented * duration) + return augmented_step[..., :-1, :] + return torch.cat([linear_step, offset], dim=-1) + + +def apply_propagation_step(state: torch.Tensor, step: torch.Tensor) -> torch.Tensor: + """Apply a precomputed exact affine propagation step to a state.""" + pools = int(state.shape[-2]) + isochromats = int(state.shape[-3]) + n = 3 * pools + if state.shape[:-3] == step.shape[:-3]: + batch = state.shape[:-3] + else: + batch = torch.broadcast_shapes(state.shape[:-3], step.shape[:-3]) + state = torch.broadcast_to(state, (*batch, isochromats, pools, 3)) + step = torch.broadcast_to(step, (*batch, isochromats, n, n + 1)) + m = state.mT.reshape(*batch, isochromats, n) + linear_step, offset = step[..., :n], step[..., n] + m_next = (linear_step @ m.unsqueeze(-1)).squeeze(-1) + offset + return m_next.reshape(*batch, isochromats, 3, pools).mT + + +def transverse_readout(state: torch.Tensor) -> torch.Tensor: + """Complex transverse readout per pool, averaged over isochromats.""" + return torch.complex(state[..., 0], state[..., 1]).mean(dim=-2) + + +class BMCBlock(TensorAttributeMixin, ABC): + """Base class for Bloch-McConnell blocks.""" + + def __call__( + self, + parameters: Parameters, + state: torch.Tensor | None = None, + *, + zero_matrix: torch.Tensor | None = None, + zero_c: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block.""" + if state is None: + state = initial_state(parameters) + if zero_matrix is None and zero_c is None: + return super().__call__(parameters, state) + return super().__call__(parameters, state, zero_matrix=zero_matrix, zero_c=zero_c) + + @abstractmethod + def forward(self, parameters: Parameters, state: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block.""" + raise NotImplementedError + + @property + def duration(self) -> torch.Tensor: + """Duration of the block.""" + return torch.as_tensor(0.0) + + +class ConstantRFBlock(BMCBlock): + """Constant RF block for a duration.""" + + def __init__( + self, + duration: torch.Tensor | float, + rf_amplitude: torch.Tensor | float, + rf_phase: torch.Tensor | float = 0.0, + rf_frequency: torch.Tensor | float = 0.0, + ) -> None: + """Initialize the block. + + Parameters + ---------- + duration + Duration in seconds. + rf_amplitude + RF amplitude in Hz. + rf_phase + RF phase in rad. + rf_frequency + RF frequency in Hz. + """ + super().__init__() + self._duration = torch.as_tensor(duration) + self.rf_amplitude = torch.as_tensor(rf_amplitude) + self.rf_phase = torch.as_tensor(rf_phase) + self.rf_frequency = torch.as_tensor(rf_frequency) + + @property + def duration(self) -> torch.Tensor: + """Duration of the block.""" + return self._duration + + def forward(self, parameters: Parameters, state: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block.""" + matrix, c = system_matrix( + parameters, + self.rf_amplitude.to(state), + self.rf_phase.to(state), + self.rf_frequency.to(state), + ) + state = propagate(state, matrix, c, self._duration.to(state)) + return state, () + + +class PiecewiseRFBlock(BMCBlock): + """Piecewise-constant RF block.""" + + def __init__( + self, + rf_amplitude: torch.Tensor, + rf_phase: torch.Tensor | float = 0.0, + rf_frequency: torch.Tensor | float = 0.0, + dt: torch.Tensor | float = 0.0, + extra_off_resonance: torch.Tensor | float = 0.0, + ) -> None: + """Initialize the block. + + Parameters + ---------- + rf_amplitude + RF amplitude. Shape ``(time, ...)``. + rf_phase + RF phase in rad. Shape ``(time, ...)``, ``(1, ...)`` or scalar. + rf_frequency + RF frequency in Hz. Shape ``(time, ...)``, ``(1, ...)`` or scalar. + dt + Sample duration in seconds. Shape ``(time, ...)``, ``(1, ...)`` or scalar. + extra_off_resonance + Additional off-resonance in rad/s. Shape ``(time, ...)``, ``(1, ...)`` or scalar. + """ + super().__init__() + self.rf_amplitude = torch.as_tensor(rf_amplitude) + self.rf_phase = torch.as_tensor(rf_phase) + self.rf_frequency = torch.as_tensor(rf_frequency) + self.dt = torch.as_tensor(dt) + self.extra_off_resonance = torch.as_tensor(extra_off_resonance) + + if self.rf_amplitude.ndim < 1: + raise ValueError('rf_amplitude must have a leading time dimension.') + + @property + def duration(self) -> torch.Tensor: + """Duration of the block.""" + if self.dt.ndim == 0: + return self.dt * self.rf_amplitude.shape[0] + if self.dt.shape[0] == 1: + return self.dt.squeeze(0) * self.rf_amplitude.shape[0] + return self.dt.sum(dim=0) + + def forward(self, parameters: Parameters, state: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block.""" + rf_amplitude = self.rf_amplitude + rf_phase = unsqueeze_left(self.rf_phase, max(0, 1 - self.rf_phase.ndim)) + rf_frequency = unsqueeze_left(self.rf_frequency, max(0, 1 - self.rf_frequency.ndim)) + dt = unsqueeze_left(self.dt, max(0, 1 - self.dt.ndim)) + extra_off_resonance = unsqueeze_left(self.extra_off_resonance, max(0, 1 - self.extra_off_resonance.ndim)) + + time = rf_amplitude.shape[0] + for name, tensor in ( + ('rf_phase', rf_phase), + ('rf_frequency', rf_frequency), + ('dt', dt), + ('extra_off_resonance', extra_off_resonance), + ): + if tensor.shape[0] not in (1, time): + raise ValueError(f'{name} must have leading dimension 1 or match rf_amplitude.') + + tensor_ndim = max( + rf_amplitude.ndim, + rf_phase.ndim, + rf_frequency.ndim, + dt.ndim, + extra_off_resonance.ndim, + state.ndim - 2, + parameters.ndim, + ) + + rf_amplitude = unsqueeze_right(rf_amplitude, tensor_ndim - rf_amplitude.ndim) + rf_phase = unsqueeze_right(rf_phase, tensor_ndim - rf_phase.ndim) + rf_frequency = unsqueeze_right(rf_frequency, tensor_ndim - rf_frequency.ndim) + dt = unsqueeze_right(dt, tensor_ndim - dt.ndim) + extra_off_resonance = unsqueeze_right(extra_off_resonance, tensor_ndim - extra_off_resonance.ndim) + + if time != 1: + if rf_phase.shape[0] == 1: + rf_phase = rf_phase.expand(time, *rf_phase.shape[1:]) + if rf_frequency.shape[0] == 1: + rf_frequency = rf_frequency.expand(time, *rf_frequency.shape[1:]) + if dt.shape[0] == 1: + dt = dt.expand(time, *dt.shape[1:]) + if extra_off_resonance.shape[0] == 1: + extra_off_resonance = extra_off_resonance.expand(time, *extra_off_resonance.shape[1:]) + + c = system_recovery_vector(parameters) + same_frequency = bool(torch.all(rf_frequency == rf_frequency[:1])) + same_extra = bool(torch.all(extra_off_resonance == extra_off_resonance[:1])) + base_matrix = ( + system_base_matrix(parameters, rf_frequency[0], extra_off_resonance[0]) + if same_frequency and same_extra + else None + ) + + batch_shape = torch.broadcast_shapes( + rf_amplitude.shape[1:], + rf_phase.shape[1:], + rf_frequency.shape[1:], + dt.shape[1:], + extra_off_resonance.shape[1:], + ) + batch_size = 1 + for size in batch_shape: + batch_size *= size + work = time * batch_size * (3 * parameters.n_pools) ** 2 + chunk_size = time if work <= 2_500_000 else (128 if same_frequency else 64) + + def run_chunk( + state: torch.Tensor, + rf_amplitude: torch.Tensor, + rf_phase: torch.Tensor, + rf_frequency: torch.Tensor, + dt_chunk: torch.Tensor, + extra_off_resonance: torch.Tensor, + ) -> torch.Tensor: + if same_frequency: + if base_matrix is not None: + matrices = base_matrix + system_rf_matrix( + parameters, + rf_amplitude, + rf_phase, + rf_frequency, + extra_off_resonance, + ) + else: + matrices = system_base_matrix( + parameters, + rf_frequency, + extra_off_resonance, + ) + system_rf_matrix( + parameters, + rf_amplitude, + rf_phase, + rf_frequency, + extra_off_resonance, + ) + else: + matrices = system_base_matrix( + parameters, + rf_frequency, + extra_off_resonance, + ) + system_rf_matrix( + parameters, + rf_amplitude, + rf_phase, + rf_frequency, + extra_off_resonance, + ) + matrices = matrices.unsqueeze(-3) + recovery = c[(None,) * (matrices.ndim - c.ndim - 2) + (...,)].unsqueeze(-2) + recovery = torch.broadcast_to(recovery, (*matrices.shape[:-2], c.shape[-1])) + steps = propagation_step(matrices, recovery, dt_chunk) + for step in steps: + state = apply_propagation_step(state, step) + return state + + n_chunks = (time + chunk_size - 1) // chunk_size + for rf_amplitude_chunk, rf_phase_chunk, rf_frequency_chunk, dt_chunk, extra_off_resonance_chunk in zip( + rf_amplitude.tensor_split(n_chunks), + rf_phase.tensor_split(n_chunks), + rf_frequency.tensor_split(n_chunks), + dt.tensor_split(n_chunks), + extra_off_resonance.tensor_split(n_chunks), + strict=True, + ): + state = activation_checkpoint( + run_chunk, + state, + rf_amplitude_chunk, + rf_phase_chunk, + rf_frequency_chunk, + dt_chunk, + extra_off_resonance_chunk, + use_reentrant=False, + preserve_rng_state=False, + ) + return state, () + + +class SliceSelectiveRFBlock(BMCBlock): + """Piecewise RF block with an effective image-space slice profile.""" + + def __init__( + self, + rf_amplitude: torch.Tensor, + slice_profile: SliceProfileBase, + positions: SpatialDimension[torch.Tensor], + rf_phase: torch.Tensor | float = 0.0, + rf_frequency: torch.Tensor | float = 0.0, + dt: torch.Tensor | float = 0.0, + ) -> None: + """Initialize the block. + + Parameters + ---------- + rf_amplitude + RF amplitude. Shape ``(time, ...)``. + slice_profile + Slice profile evaluated at ``positions.z`` to scale excitation across isochromats. + positions + Isochromat positions in meters for ``(z, y, x)``. + rf_phase + RF phase in rad. Shape ``(time, ...)``, ``(1, ...)`` or scalar. + rf_frequency + RF frequency in Hz. Shape ``(time, ...)``, ``(1, ...)`` or scalar. + dt + Sample duration in seconds. Shape ``(time, ...)``, ``(1, ...)`` or scalar. + """ + super().__init__() + self._block = PiecewiseRFBlock(rf_amplitude=rf_amplitude, rf_phase=rf_phase, rf_frequency=rf_frequency, dt=dt) + self.slice_profile = slice_profile + self.positions = positions.apply(torch.as_tensor) + self.n_iso = max(int(self.positions.z.numel()), int(self.positions.y.numel()), int(self.positions.x.numel())) + if any(axis.numel() not in (1, self.n_iso) for axis in (self.positions.z, self.positions.y, self.positions.x)): + raise ValueError( + 'positions.x, positions.y and positions.z must each have length 1 or match the isochromat count.' + ) + + @property + def duration(self) -> torch.Tensor: + """Duration of the block.""" + return self._block.duration + + def forward(self, parameters: Parameters, state: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block.""" + n_iso = state.shape[-3] + if n_iso != self.n_iso: + raise ValueError( + f'Isochromat axis mismatch: state has {n_iso} isochromats but positions define {self.n_iso}.' + ) + + profile = self.slice_profile(self.positions.z.to(state)) + effective_rf_amplitude = self._block.rf_amplitude[..., None] * profile + if parameters.relative_b1 is not None: + effective_rf_amplitude = effective_rf_amplitude * parameters.relative_b1[..., None] + effective_block = PiecewiseRFBlock( + rf_amplitude=effective_rf_amplitude, + rf_phase=self._block.rf_phase, + rf_frequency=self._block.rf_frequency, + dt=self._block.dt, + extra_off_resonance=self._block.extra_off_resonance, + ) + effective_parameters = Parameters( + equilibrium_magnetization=parameters.equilibrium_magnetization, + t1=parameters.t1, + t2=parameters.t2, + exchange_rate=parameters.exchange_rate, + chemical_shift=parameters.chemical_shift, + static_off_resonance=parameters.static_off_resonance, + relative_b1=None, + mt_saturation=parameters.mt_saturation, + ) + state, outputs = effective_block(effective_parameters, state.unsqueeze(-3)) + return state.squeeze(-3), outputs + + +class DelayBlock(BMCBlock): + """Delay without RF.""" + + def __init__(self, duration: torch.Tensor | float) -> None: + """Initialize the block. + + Parameters + ---------- + duration + Duration in seconds. Shape ``(..., pools)``. + """ + super().__init__() + self._duration = torch.as_tensor(duration) + + @property + def duration(self) -> torch.Tensor: + """Duration of the block.""" + return self._duration + + def forward( + self, + parameters: Parameters, + state: torch.Tensor, + *, + zero_matrix: torch.Tensor | None = None, + zero_c: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block. + + Parameters + ---------- + parameters + Simulation parameters. + state + State tensor. Shape ``(..., isochromats, pools, 3)``. + zero_matrix + Cached no-RF system matrix for the current sequence execution, if available. + zero_c + Cached no-RF recovery vector for the current sequence execution, if available. + + Returns + ------- + state + State tensor. Shape ``(..., pools, 3)``. + """ + if zero_matrix is None or zero_c is None: + matrix, c = system_matrix(parameters, 0.0, 0.0, 0.0) + else: + matrix, c = zero_matrix, zero_c + state = propagate(state, matrix, c, self._duration.to(state)) + return state, () + + +class GradientBlock(BMCBlock): + """Gradient-only block with position-dependent phase accrual.""" + + def __init__( + self, + duration: torch.Tensor | float, + positions: SpatialDimension[torch.Tensor], + gradient_x: torch.Tensor | float = 0.0, + gradient_y: torch.Tensor | float = 0.0, + gradient_z: torch.Tensor | float = 0.0, + ) -> None: + """Initialize the block. + + Parameters + ---------- + duration + Duration in seconds. + gradient_x + Gradient amplitude in T/m along x. + gradient_y + Gradient amplitude in T/m along y. + gradient_z + Gradient amplitude in T/m along z. + positions + Isochromat positions in meters for ``(z, y, x)``. + """ + super().__init__() + self._duration = torch.as_tensor(duration) + self.gradient_x = torch.as_tensor(gradient_x) + self.gradient_y = torch.as_tensor(gradient_y) + self.gradient_z = torch.as_tensor(gradient_z) + self.positions = positions.apply(torch.as_tensor) + self.n_iso = max(int(self.positions.z.numel()), int(self.positions.y.numel()), int(self.positions.x.numel())) + if any(axis.numel() not in (1, self.n_iso) for axis in (self.positions.z, self.positions.y, self.positions.x)): + raise ValueError( + 'positions.x, positions.y and positions.z must each have length 1 or match the isochromat count.' + ) + + @property + def duration(self) -> torch.Tensor: + """Duration of the block.""" + return self._duration + + def forward(self, parameters: Parameters, state: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block.""" + n_iso = state.shape[-3] + if n_iso != self.n_iso: + raise ValueError( + f'Isochromat axis mismatch: state has {n_iso} isochromats but positions define {self.n_iso}.' + ) + + extra_off_resonance = gradient_to_extra_off_resonance( + self.gradient_z, + self.gradient_y, + self.gradient_x, + self.positions, + ) + + matrix = system_base_matrix(parameters, 0.0, extra_off_resonance) + c = system_recovery_vector(parameters) + c = c.unsqueeze(-2) + duration = self._duration.unsqueeze(-1) + state = apply_propagation_step(state, propagation_step(matrix, c, duration)) + return state, () + + +class SpoilBlock(DelayBlock): + """Perfect spoiling with non-zero duration.""" + + def forward( + self, + parameters: Parameters, + state: torch.Tensor, + *, + zero_matrix: torch.Tensor | None = None, + zero_c: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block. + + Parameters + ---------- + parameters + Simulation parameters. + state + State tensor. Shape ``(..., pools, 3)``. + zero_matrix + Cached no-RF system matrix for the current sequence execution, if available. + zero_c + Cached no-RF recovery vector for the current sequence execution, if available. + + Returns + ------- + state + State tensor. Shape ``(..., pools, 3)``. + """ + state, out = super().forward(parameters, state, zero_matrix=zero_matrix, zero_c=zero_c) + mx, _, mz = state.unbind(-1) + z = torch.zeros_like(mx) + return torch.stack([z, z, mz], dim=-1), out + + +class AcquisitionBlock(BMCBlock): + """Acquisition block that emits a readout.""" + + def __init__(self, pool_index: int | None = None) -> None: + """Initialize the block. + + Parameters + ---------- + pool_index + Pool index to read out. If ``None``, emit transverse signal for all pools. + """ + super().__init__() + self.pool_index = pool_index + + def forward(self, parameters: Parameters, state: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block. + + Parameters + ---------- + parameters + Simulation parameters. + state + State tensor. Shape ``(..., pools, 3)``. + + Returns + ------- + state + State tensor. Shape ``(..., isochromats, pools, 3)``. + """ + signal = transverse_readout(state) + if self.pool_index is None: + return state, (signal,) + if not (0 <= self.pool_index < parameters.n_pools): + raise ValueError('pool_index out of bounds.') + return state, (signal[..., self.pool_index],) + + +class LongitudinalReadoutBlock(BMCBlock): + """Read out longitudinal magnetization of a selected pool.""" + + def __init__(self, pool_index: int = 0) -> None: + """Initialize the block. + + Parameters + ---------- + pool_index + Pool index to read out. + """ + super().__init__() + self.pool_index = pool_index + + def forward(self, parameters: Parameters, state: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block.""" + if not (0 <= self.pool_index < parameters.n_pools): + raise ValueError('pool_index out of bounds.') + return state, (state[..., self.pool_index, 2].mean(dim=-1),) + + +class ResetBlock(BMCBlock): + """Reset state to equilibrium or to a provided state.""" + + def __init__(self, state: torch.Tensor | None = None) -> None: + """Initialize the block. + + Parameters + ---------- + state + State tensor. Shape ``(..., pools, 3)``. + """ + super().__init__() + self.state = state + + @property + def duration(self) -> torch.Tensor: + """Duration of the block.""" + return torch.as_tensor(0.0) + + def forward(self, parameters: Parameters, state: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the block. + + Parameters + ---------- + parameters + Simulation parameters. + state + State tensor. Shape ``(..., isochromats, pools, 3)``. + + Returns + ------- + state + State tensor. Shape ``(..., isochromats, pools, 3)``. + """ + if self.state is None: + equilibrium = initial_state(parameters).to(state) + batch = torch.broadcast_shapes(equilibrium.shape[:-3], state.shape[:-3]) + equilibrium = torch.broadcast_to( + equilibrium, + (*batch, state.shape[-3], equilibrium.shape[-2], equilibrium.shape[-1]), + ) + return equilibrium.clone(), () + reset_state = self.state.to(state) + if reset_state.ndim == state.ndim - 1: + reset_state = reset_state.unsqueeze(-3) + batch = torch.broadcast_shapes(reset_state.shape[:-3], state.shape[:-3]) + reset_state = torch.broadcast_to( + reset_state, + (*batch, state.shape[-3], reset_state.shape[-2], reset_state.shape[-1]), + ) + return reset_state.clone(), () + + +class BMCSequence(torch.nn.ModuleList, BMCBlock): + """Sequence of Bloch-McConnell blocks.""" + + def __init__(self, blocks: Sequence[BMCBlock] = ()) -> None: + """Initialize the sequence. + + Parameters + ---------- + blocks + Sequence of Bloch-McConnell blocks. + """ + torch.nn.ModuleList.__init__(self, blocks) + + @property + def duration(self) -> torch.Tensor: + """Duration of the sequence.""" + return sum( + (b.duration for b in self if isinstance(b, BMCBlock)), + start=torch.as_tensor(0.0), + ) + + def forward( + self, + parameters: Parameters, + state: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + """Apply the sequence of blocks. + + Parameters + ---------- + parameters + Simulation parameters. + state + State tensor. Shape ``(..., isochromats, pools, 3)``. + + Returns + ------- + state + State tensor. Shape ``(..., pools, 3)``. + outputs + List of output tensors. + """ + parameters = parameters.to(state, copy=False) + zero_matrix: torch.Tensor | None = None + zero_c: torch.Tensor | None = None + outputs: list[torch.Tensor] = [] + for block in self: + assert isinstance(block, BMCBlock) # noqa: S101 + if isinstance(block, DelayBlock | SpoilBlock): + if zero_matrix is None or zero_c is None: + zero_matrix, zero_c = system_matrix(parameters, 0.0, 0.0, 0.0) + state, out = block(parameters, state, zero_matrix=zero_matrix, zero_c=zero_c) + else: + state, out = block(parameters, state) + outputs.extend(out) + return state, tuple(outputs) diff --git a/src/mrpro/operators/models/__init__.py b/src/mrpro/operators/models/__init__.py index e597b753c..055721c12 100644 --- a/src/mrpro/operators/models/__init__.py +++ b/src/mrpro/operators/models/__init__.py @@ -9,11 +9,13 @@ from mrpro.operators.models.MonoExponentialDecay import MonoExponentialDecay from mrpro.operators.models.cMRF import CardiacFingerprinting from mrpro.operators.models.TransientSteadyStateWithPreparation import TransientSteadyStateWithPreparation +from mrpro.operators.models import BMC from mrpro.operators.models import EPG from mrpro.operators.models.MESE import MultiEchoSpinEcho from mrpro.operators.models.NeuroMRF import NeuroMRF __all__ = [ + "BMC", "CardiacFingerprinting", "EPG", "InversionRecovery", @@ -26,4 +28,4 @@ "TransientSteadyStateWithPreparation", "WASABI", "WASABITI" -] \ No newline at end of file +] diff --git a/src/mrpro/utils/slice_profiles.py b/src/mrpro/utils/slice_profiles.py index 31f7b5f85..f7a8a6057 100644 --- a/src/mrpro/utils/slice_profiles.py +++ b/src/mrpro/utils/slice_profiles.py @@ -8,8 +8,17 @@ import torch from mrpro.utils.TensorAttributeMixin import TensorAttributeMixin +from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON -__all__ = ['SliceGaussian', 'SliceInterpolate', 'SliceProfileBase', 'SliceSmoothedRectangular'] +__all__ = [ + 'GaussianRFPulse', + 'SincRFPulse', + 'SliceGaussian', + 'SliceInterpolate', + 'SliceProfileBase', + 'SliceRFPulseBase', + 'SliceSmoothedRectangular', +] class SliceProfileBase(abc.ABC, TensorAttributeMixin, torch.nn.Module): @@ -37,6 +46,109 @@ def random_sample(self, size: Sequence[int]) -> torch.Tensor: raise NotImplementedError +class SliceRFPulseBase(abc.ABC, TensorAttributeMixin, torch.nn.Module): + """Base class for slice-selective RF pulse templates.""" + + @abc.abstractmethod + def forward( + self, flip_angle: torch.Tensor | float, duration: torch.Tensor | float, dt: torch.Tensor | float + ) -> torch.Tensor: + """Create a discrete RF waveform in Tesla.""" + raise NotImplementedError + + def rf_and_phase( + self, + flip_angle: torch.Tensor | float, + duration: torch.Tensor | float, + dt: torch.Tensor | float, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Create a discrete RF waveform and phase in rad.""" + rf = self(flip_angle=flip_angle, duration=duration, dt=dt) + return rf, torch.zeros_like(rf) + + +def _n_samples(duration: torch.Tensor | float, dt: torch.Tensor | float) -> int: + duration_value = torch.as_tensor(duration).item() + dt_value = torch.as_tensor(dt).item() + if duration_value <= 0 or dt_value <= 0: + raise ValueError('duration and dt must be positive.') + samples = round(duration_value / dt_value) + if samples < 1: + raise ValueError('duration / dt must produce at least one RF sample.') + return samples + + +def _scale_waveform_to_flip_angle( + waveform: torch.Tensor, + flip_angle: torch.Tensor | float, + dt: torch.Tensor | float, +) -> torch.Tensor: + flip_angle = torch.as_tensor(flip_angle, dtype=waveform.dtype, device=waveform.device) + dt = torch.as_tensor(dt, dtype=waveform.dtype, device=waveform.device) + return waveform * (flip_angle / (GYROMAGNETIC_RATIO_PROTON * dt * waveform.sum())) + + +class GaussianRFPulse(SliceRFPulseBase): + """Gaussian RF pulse template.""" + + fwhm_fraction: torch.Tensor + + def __init__(self, fwhm_fraction: float | torch.Tensor = 0.35): + """Initialize the Gaussian pulse template. + + Parameters + ---------- + fwhm_fraction + RF Gaussian FWHM relative to pulse duration. + """ + super().__init__() + self.register_buffer('fwhm_fraction', torch.as_tensor(fwhm_fraction)) + + def forward( + self, flip_angle: torch.Tensor | float, duration: torch.Tensor | float, dt: torch.Tensor | float + ) -> torch.Tensor: + """Create a Gaussian RF waveform in Tesla.""" + samples = _n_samples(duration, dt) + duration = torch.as_tensor(duration) + time = torch.linspace(-0.5, 0.5, samples, dtype=duration.dtype, device=duration.device) + sigma = self.fwhm_fraction / (2 * (2 * log(2)) ** 0.5) + waveform = torch.exp(-0.5 * (time / sigma) ** 2) + return _scale_waveform_to_flip_angle(waveform, flip_angle, dt) + + +class SincRFPulse(SliceRFPulseBase): + """Apodized sinc RF pulse template.""" + + time_bandwidth: torch.Tensor + apodization: torch.Tensor + + def __init__(self, time_bandwidth: float | torch.Tensor = 4.0, apodization: float | torch.Tensor = 0.5): + """Initialize the sinc pulse template. + + Parameters + ---------- + time_bandwidth + Time-bandwidth product of the sinc pulse. + apodization + Raised-cosine apodization in ``[0, 1]``. + """ + super().__init__() + self.register_buffer('time_bandwidth', torch.as_tensor(time_bandwidth)) + self.register_buffer('apodization', torch.as_tensor(apodization)) + + def forward( + self, flip_angle: torch.Tensor | float, duration: torch.Tensor | float, dt: torch.Tensor | float + ) -> torch.Tensor: + """Create an apodized sinc RF waveform in Tesla.""" + samples = _n_samples(duration, dt) + duration = torch.as_tensor(duration) + time = torch.linspace(-0.5, 0.5, samples, dtype=duration.dtype, device=duration.device) + sinc = torch.sinc(self.time_bandwidth.to(time) * time) + window = 1 - self.apodization.to(time) + self.apodization.to(time) * torch.cos(2 * torch.pi * time) + waveform = sinc * window + return _scale_waveform_to_flip_angle(waveform, flip_angle, dt) + + class SliceGaussian(SliceProfileBase): """Gaussian slice profile.""" diff --git a/tests/operators/models/test_bmc.py b/tests/operators/models/test_bmc.py new file mode 100644 index 000000000..b84751d39 --- /dev/null +++ b/tests/operators/models/test_bmc.py @@ -0,0 +1,1009 @@ +"""Tests for BMsim Bloch-McConnell sequence models.""" + +from collections.abc import Sequence + +import pytest +import torch +from mrpro.data.SpatialDimension import SpatialDimension +from mrpro.operators.models.BMC import ( + AcquisitionBlock, + BMCSequence, + ConstantRFBlock, + DelayBlock, + GradientBlock, + LongitudinalReadoutBlock, + Parameters, + PiecewiseRFBlock, + ResetBlock, + SliceSelectiveRFBlock, + SpoilBlock, + gradient_to_extra_off_resonance, + initial_state, +) +from mrpro.operators.SignalModel import SignalModel +from mrpro.utils.slice_profiles import ( + GaussianRFPulse, + SincRFPulse, + SliceGaussian, + SliceRFPulseBase, + SliceSmoothedRectangular, +) +from mrpro.utils.unit_conversion import GYROMAGNETIC_RATIO_PROTON, magnetic_field_to_lamor_frequency + +# fmt: off +REFERENCE = { + 1: [0.999999725, 0.997037526, 0.996999091, 0.996957322, 0.996916546, 0.996874953, 0.996830535, 0.996787948, 0.996742052, 0.996696512, 0.996650125, 0.996600829, 0.99655297, 0.99650272, 0.996450994, 0.996400371, 0.996345566, 0.996291411, 0.996236946, 0.996178615, 0.996120749, 0.996061895, 0.995999279, 0.995937802, 0.995874367, 0.99580727, 0.995740451, 0.995673261, 0.995601302, 0.995529165, 0.995457207, 0.995380092, 0.995301711, 0.99522439, 0.995142152, 0.99505688, 0.994971427, 0.994885228, 0.994793485, 0.99469933, 0.994605829, 0.994508622, 0.994406833, 0.994301924, 0.994196715, 0.99408874, 0.993976273, 0.99385794, 0.993738675, 0.99361716, 0.993493045, 0.993362714, 0.993227731, 0.993088635, 0.992946, 0.992800253, 0.992649998, 0.992495353, 0.992334151, 0.992167521, 0.991995382, 0.991815626, 0.991632027, 0.991442331, 0.991245544, 0.991042576, 0.990832521, 0.990615167, 0.990389556, 0.99015606, 0.989914647, 0.98966289, 0.98940384, 0.989135144, 0.988855634, 0.988563955, 0.988258607, 0.987939792, 0.987610101, 0.987271105, 0.986917702, 0.986542716, 0.986151511, 0.98575259, 0.985329487, 0.984882702, 0.984428464, 0.983939336, 0.983437563, 0.982905152, 0.982352491, 0.981764643, 0.981160524, 0.98050853, 0.979834899, 0.979129259, 0.978375245, 0.977578219, 0.976741414, 0.975859615, 0.974926597, 0.973937906, 0.972891887, 0.971788738, 0.970615354, 0.969347641, 0.968021866, 0.966590256, 0.965084311, 0.963447279, 0.961701924, 0.959837223, 0.957835087, 0.955693565, 0.953382029, 0.950862001, 0.948155283, 0.945232573, 0.942040943, 0.938574182, 0.93481523, 0.930704225, 0.926179173, 0.921187814, 0.915724556, 0.909690314, 0.903037007, 0.895597655, 0.887384453, 0.878097754, 0.867736655, 0.856126995, 0.842965626, 0.827976175, 0.810984686, 0.791715357, 0.769519727, 0.744002894, 0.714601904, 0.680711679, 0.641357215, 0.596238563, 0.544658719, 0.485582702, 0.419503889, 0.346013379, 0.268588682, 0.190166045, 0.117143641, 0.055823878, 0.015334864, 0.001867055, 0.015133695, 0.05484406, 0.114463081, 0.184882753, 0.26000293, 0.333732373, 0.403350114, 0.465533801, 0.520745963, 0.568490967, 0.609809063, 0.645352873, 0.675462055, 0.701208651, 0.72341188, 0.742954099, 0.760643992, 0.777530675, 0.794115671, 0.810490176, 0.826523618, 0.841942266, 0.856362016, 0.869381311, 0.881079159, 0.891371009, 0.900447082, 0.908417745, 0.915464929, 0.921670226, 0.927157187, 0.932044587, 0.93644288, 0.940405547, 0.94397017, 0.947213781, 0.950186662, 0.952880517, 0.95534793, 0.957628661, 0.959732543, 0.961684255, 0.96350057, 0.965162722, 0.966730357, 0.968176197, 0.969549343, 0.9708158, 0.972001824, 0.973121661, 0.9741762, 0.975168133, 0.976102785, 0.976987139, 0.977827147, 0.978620246, 0.979361309, 0.98006685, 0.98074815, 0.981379095, 0.981991194, 0.982566468, 0.983119247, 0.983640129, 0.984146574, 0.984616928, 0.985078705, 0.985515431, 0.985927189, 0.986330359, 0.986716288, 0.987079889, 0.98742852, 0.987767233, 0.988094412, 0.988407528, 0.988706474, 0.98899282, 0.989267961, 0.989533098, 0.98979054, 0.990037344, 0.990275919, 0.990506265, 0.990728133, 0.990942449, 0.99114944, 0.991349987, 0.991543283, 0.991730293, 0.991913285, 0.992088506, 0.99225807, 0.992422066, 0.992579401, 0.992732167, 0.992880292, 0.993025146, 0.993166409, 0.99330348, 0.993435823, 0.993561866, 0.993685165, 0.993806123, 0.993926112, 0.9940402, 0.994149711, 0.99425632, 0.994362605, 0.994465758, 0.994564282, 0.994658979, 0.994754289, 0.994847192, 0.994934516, 0.995020978, 0.995107244, 0.995190491, 0.995268746, 0.995347991, 0.995425995, 0.995498816, 0.995571717, 0.99564446, 0.995712438, 0.995779939, 0.995847738, 0.995911886, 0.995973974, 0.996037227, 0.996096724, 0.996155133, 0.996214054, 0.996269093, 0.996323743, 0.996379092, 0.996430215, 0.996482391, 0.996533152, 0.996581439, 0.996631167, 0.996678016, 0.99672393, 0.996770254, 0.996813236, 0.996858009, 0.9969, 0.996941102, 0.996983241, 0.997022003], # noqa: E501 + 2: [0.999999834, 0.997186629, 0.997927991, 0.996945247, 0.998023983, 0.996787008, 0.998069423, 0.996613949, 0.998045414, 0.996493655, 0.997997875, 0.996404321, 0.997921128, 0.996361714, 0.997785488, 0.996303164, 0.997600352, 0.99631996, 0.997462602, 0.996273639, 0.997223261, 0.996322854, 0.997068795, 0.996383121, 0.996801296, 0.996326443, 0.996645368, 0.996380381, 0.996367039, 0.996293652, 0.996089239, 0.996322385, 0.995947521, 0.996334417, 0.995679767, 0.996187305, 0.995557901, 0.996156233, 0.995305174, 0.995955071, 0.995064257, 0.995867943, 0.994986487, 0.995592918, 0.994773764, 0.995433828, 0.994752357, 0.995234719, 0.994588948, 0.994801744, 0.994458256, 0.994484976, 0.994560493, 0.993904521, 0.994515603, 0.993445817, 0.994714535, 0.99291151, 0.994747972, 0.992149798, 0.994934003, 0.991543228, 0.994912409, 0.990919299, 0.994777219, 0.990521769, 0.994336245, 0.990509797, 0.993571472, 0.990750126, 0.992117869, 0.991332925, 0.990540173, 0.992353479, 0.988752973, 0.992982999, 0.987514167, 0.992675955, 0.987448127, 0.990895509, 0.988795021, 0.98785849, 0.990722767, 0.985021701, 0.99120227, 0.984664701, 0.988380779, 0.987371125, 0.983516853, 0.989734082, 0.981654216, 0.986600195, 0.985581449, 0.979892362, 0.987826434, 0.980281964, 0.980267103, 0.986345595, 0.976234926, 0.979681559, 0.984337172, 0.973280216, 0.976310428, 0.983030535, 0.972298692, 0.969069946, 0.979190168, 0.976589488, 0.964117418, 0.964106963, 0.974012229, 0.975067187, 0.964173628, 0.953489793, 0.951519377, 0.956328011, 0.96196808, 0.964859665, 0.964810796, 0.963131894, 0.960720897, 0.956852138, 0.948144786, 0.93041542, 0.9098532, 0.914648444, 0.941191307, 0.902060495, 0.902924311, 0.897794407, 0.906406806, 0.844314807, 0.837731462, 0.827817913, 0.791135605, 0.83598323, 0.772931812, 0.826446397, 0.766362064, 0.684086248, 0.598773545, 0.752622976, 0.687224584, 0.683461119, 0.633510343, 0.541020164, 0.489942119, -0.004644275, 0.147154823, 0.191412109, 0.117639452, -0.138237114, 0.117756436, 0.190129604, 0.144400355, -0.009116538, 0.483271619, 0.532766098, 0.622484731, 0.670492193, 0.67237748, 0.735837956, 0.579465689, 0.662796529, 0.743796273, 0.801950629, 0.74682652, 0.808253879, 0.762971686, 0.799944257, 0.810956592, 0.819404113, 0.884063385, 0.878553329, 0.886370772, 0.888335862, 0.929686995, 0.905053653, 0.901966589, 0.923897633, 0.942692098, 0.952251646, 0.956819947, 0.959811662, 0.961977302, 0.962435326, 0.959885075, 0.954520633, 0.949923267, 0.952055428, 0.962886058, 0.973940689, 0.973043666, 0.963242646, 0.963305621, 0.975850301, 0.978557631, 0.968497867, 0.971743392, 0.982538561, 0.975885219, 0.972859548, 0.983951912, 0.979354089, 0.975904509, 0.986044195, 0.980008731, 0.980012947, 0.987593118, 0.979679325, 0.985361083, 0.986420315, 0.98146738, 0.989562364, 0.983366981, 0.987209232, 0.988251216, 0.984524649, 0.991078631, 0.984906902, 0.990602851, 0.98776232, 0.98868429, 0.990808633, 0.987351471, 0.992592511, 0.987432567, 0.992901552, 0.988684063, 0.992275334, 0.990480273, 0.99126011, 0.992063569, 0.990684125, 0.993520113, 0.990450875, 0.994287083, 0.990470426, 0.994729041, 0.990873969, 0.994866103, 0.991503729, 0.99488971, 0.992114455, 0.994705338, 0.992879988, 0.994674253, 0.993417036, 0.994477461, 0.993878056, 0.994524691, 0.994460558, 0.994424676, 0.99477905, 0.994557442, 0.995213486, 0.994722743, 0.995414021, 0.994746011, 0.995574437, 0.994960286, 0.995850555, 0.995039626, 0.995938858, 0.9952818, 0.996140974, 0.995535684, 0.996173088, 0.99565873, 0.996321047, 0.995927471, 0.996309907, 0.996070184, 0.996281946, 0.996348884, 0.996369353, 0.996628134, 0.996315981, 0.996784909, 0.996373201, 0.997053357, 0.996313265, 0.997208637, 0.996264225, 0.997449004, 0.996310849, 0.997587567, 0.996294001, 0.997773808, 0.996352706, 0.997910589, 0.996395095, 0.997988157, 0.996484482, 0.998036813, 0.996604492, 0.998061579, 0.996777312, 0.99801711, 0.996935617, 0.99792193, 0.997176929], # noqa: E501 + 3: [0.987722603, 0.524032356, 0.523413286, 0.52279303, 0.522172487, 0.52155109, 0.520928268, 0.520304913, 0.519680434, 0.519054252, 0.518427248, 0.517798828, 0.517168886, 0.516536819, 0.515903488, 0.515268285, 0.514630582, 0.513991247, 0.513349651, 0.512705638, 0.512058558, 0.511409271, 0.510757109, 0.510101397, 0.509442974, 0.508781153, 0.508115744, 0.507446024, 0.506772836, 0.506095449, 0.505413091, 0.504726605, 0.504035214, 0.503338118, 0.502636152, 0.501928484, 0.501214831, 0.50049433, 0.499767794, 0.499034337, 0.498293041, 0.497544698, 0.496788365, 0.496023655, 0.495249552, 0.494466853, 0.493674497, 0.492871395, 0.492058321, 0.491234143, 0.49039768, 0.489549671, 0.488688893, 0.487814738, 0.486925866, 0.486023004, 0.48510477, 0.484169674, 0.483218421, 0.48224948, 0.481262005, 0.48025432, 0.479227055, 0.478178458, 0.477106654, 0.476012244, 0.474893262, 0.47374853, 0.472575842, 0.47137573, 0.4701459, 0.468883901, 0.467590164, 0.4662619, 0.464897492, 0.463495261, 0.462053322, 0.460569693, 0.459042289, 0.457468919, 0.455847256, 0.454174872, 0.452449204, 0.450667703, 0.448827244, 0.446924986, 0.444957795, 0.442922377, 0.440815276, 0.438632878, 0.436371363, 0.434026768, 0.431594937, 0.429071535, 0.426452056, 0.423731989, 0.420906057, 0.417969417, 0.414916873, 0.411743032, 0.408442399, 0.405009337, 0.401438094, 0.397722815, 0.39385758, 0.389836331, 0.385652994, 0.381301394, 0.37677563, 0.372068557, 0.367173933, 0.362084882, 0.356794162, 0.351293986, 0.345575947, 0.339630828, 0.333448393, 0.327017346, 0.320325198, 0.313358194, 0.306102061, 0.298539812, 0.290655518, 0.282432395, 0.273854043, 0.264905067, 0.255571797, 0.245843286, 0.235712216, 0.22517585, 0.214237289, 0.20290658, 0.191201969, 0.179152225, 0.166794238, 0.154179984, 0.141374028, 0.128455476, 0.115518471, 0.102672276, 0.090040725, 0.077760797, 0.065980583, 0.05485626, 0.044548609, 0.035217132, 0.027016481, 0.020089613, 0.014562249, 0.010537353, 0.008090503, 0.007266097, 0.008075353, 0.010495788, 0.01447248, 0.019920754, 0.026730555, 0.03477152, 0.043898822, 0.053958886, 0.064796301, 0.076257429, 0.088196225, 0.100477596, 0.112980219, 0.125598173, 0.138241713, 0.150836853, 0.163323936, 0.175655499, 0.187792499, 0.199705144, 0.211364609, 0.222744874, 0.233820674, 0.244567333, 0.254961401, 0.26498154, 0.274609988, 0.283834393, 0.292650043, 0.301063206, 0.309095215, 0.316785233, 0.324192296, 0.331382836, 0.338414833, 0.345316738, 0.352078945, 0.358663867, 0.365026393, 0.371131096, 0.376959224, 0.382507333, 0.387782679, 0.392798135, 0.3975704, 0.40211549, 0.406449289, 0.410586613, 0.414541128, 0.418325236, 0.421950242, 0.425426435, 0.428763211, 0.43196915, 0.435052177, 0.438019547, 0.44087776, 0.443633459, 0.446292219, 0.448859397, 0.451339995, 0.453738686, 0.456059824, 0.45830752, 0.460485594, 0.46259764, 0.464647036, 0.466636954, 0.468570226, 0.470449977, 0.472278698, 0.474058887, 0.475792894, 0.477482959, 0.479131178, 0.480739537, 0.482309916, 0.483844093, 0.485343633, 0.486810033, 0.488246142, 0.489651157, 0.491027565, 0.492377703, 0.493700668, 0.494998714, 0.496272933, 0.497525347, 0.498754992, 0.499963808, 0.501153609, 0.502323416, 0.503474975, 0.504609066, 0.505727302, 0.50682863, 0.507914629, 0.508986765, 0.510043976, 0.511087693, 0.512118477, 0.513137647, 0.514144136, 0.515139226, 0.516124153, 0.517097818, 0.518061416, 0.519016117, 0.519960808, 0.520896613, 0.52182389, 0.522743691, 0.523654921, 0.524558607, 0.525455753, 0.52634524, 0.52722805, 0.528104441, 0.528975363, 0.529839687, 0.530698331, 0.531552186, 0.532400122, 0.533243033, 0.534081762, 0.534915187, 0.535744149, 0.536568823, 0.537390011, 0.538206604, 0.539019393, 0.539829161, 0.540634783, 0.541437023, 0.542236029, 0.543032543, 0.54382544, 0.544615461, 0.545403327, 0.546187927, 0.546969976, 0.547749574, 0.548527425, 0.549302402, 0.550075198, 0.550846514, 0.551615218, 0.552381993, 0.553147512, 0.553910644], # noqa: E501 + 4: [0.999903038, 0.47010593, 0.455129322, 0.446606945, 0.444830803, 0.449897134, 0.46170758, 0.479957189, 0.504154088, 0.533607718, 0.567474935, 0.604745663, 0.644311629, 0.684951618, 0.72541419, 0.764404738, 0.800672738, 0.833001739, 0.860286964, 0.881549312, 0.895959567, 0.902886212, 0.901899202, 0.89279392, 0.875601662, 0.85058051, 0.818231376, 0.77926084, 0.734584076, 0.685288251, 0.632609761, 0.577893447, 0.52256042, 0.468065314, 0.415851814, 0.367317087, 0.323764747, 0.286372557, 0.256155335, 0.233934308, 0.220316088, 0.215672389, 0.22013034, 0.233568318, 0.255619934, 0.285683563, 0.322942448, 0.366385596, 0.414838303, 0.466999131, 0.521472162, 0.576813977, 0.631569226, 0.684315115, 0.73370422, 0.778496733, 0.817601368, 0.850098236, 0.875275657, 0.892627275, 0.901889564, 0.903025771, 0.896235403, 0.881943722, 0.860778107, 0.833564266, 0.801278643, 0.76502423, 0.726016625, 0.685506465, 0.644789446, 0.60511898, 0.567719162, 0.533701837, 0.504081384, 0.479705665, 0.461270425, 0.449272807, 0.444023188, 0.445625131, 0.453987321, 0.468822314], # noqa: E501 + 5: [0.99999991, 0.998269816, 0.99821325, 0.998116004, 0.998010688, 0.997936494, 0.99791761, 0.99794804, 0.997992652, 0.998004616, 0.997950957, 0.997832327, 0.997685289, 0.99756501, 0.997515681, 0.99754477, 0.997615672, 0.997664214, 0.997631735, 0.997498303, 0.99729698, 0.997100739, 0.99698547, 0.996988942, 0.997086166, 0.99719703, 0.997224518, 0.997105702, 0.996849146, 0.996537063, 0.996289024, 0.996201278, 0.99629288, 0.996487452, 0.996643499, 0.996621568, 0.99635779, 0.995905416, 0.99542026, 0.995090662, 0.995045014, 0.995278009, 0.995637705, 0.995884619, 0.995802046, 0.995306772, 0.994506389, 0.993670964, 0.993118845, 0.993069285, 0.993521198, 0.994225219, 0.994767378, 0.994743126, 0.993946235, 0.992493744, 0.990814899, 0.989503552, 0.989068684, 0.989694661, 0.991104899, 0.992604964, 0.99331158, 0.992495901, 0.98991351, 0.985992145, 0.981778996, 0.978634055, 0.977745218, 0.979615068, 0.983690576, 0.988280968, 0.990826588, 0.988479048, 0.978855355, 0.960772201, 0.934758674, 0.903198269, 0.870041735, 0.840137911, 0.818322171, 0.808452317, 0.81260542, 0.830580072, 0.859808372, 0.895675961, 0.932175471, 0.96277204, 0.981305326, 0.982812218, 0.964135715, 0.924264637, 0.864391818, 0.787704996, 0.699007585, 0.604195616, 0.509697692, 0.421927155, 0.346794366, 0.289297388, 0.253220723, 0.240917624, 0.253185843, 0.2892278, 0.346690412, 0.421789339, 0.50952667, 0.603992164, 0.698772562, 0.787439286, 0.864096266, 0.923939983, 0.963782538, 0.982430914, 0.980896115, 0.962335023, 0.931710729, 0.895183693, 0.859289037, 0.830034526, 0.812035011, 0.807858911, 0.817708096, 0.839505826, 0.869394429, 0.902538432, 0.934088683, 0.960093951, 0.978170176, 0.987787719, 0.990129438, 0.987578032, 0.982981776, 0.978900362, 0.977024694, 0.977907954, 0.981047668, 0.985255966, 0.989172792, 0.991750825, 0.992562175, 0.991851135, 0.990346476, 0.988931475, 0.988300609, 0.988730521, 0.990036891, 0.991710736, 0.993158169, 0.993949867, 0.993968719, 0.993420904, 0.992710953, 0.992252858, 0.992296021, 0.992841567, 0.993670259, 0.994463732, 0.994951886, 0.995027081, 0.994772501, 0.994404831, 0.994163559, 0.994200646, 0.994521401, 0.994997431, 0.995440366, 0.995694363, 0.99570616, 0.995539626, 0.995334277, 0.995231694, 0.995308385, 0.995545412, 0.995846643, 0.996092584, 0.996201055, 0.996163506, 0.996042845, 0.995936081, 0.995923333, 0.996029617, 0.996217241, 0.996410423, 0.99653636, 0.996562191, 0.996508103, 0.996433042, 0.996401483, 0.996450339, 0.996572423, 0.99672379, 0.996849463, 0.996913013, 0.996913849, 0.99688496, 0.996873109, 0.99691333, 0.997011436, 0.997143002, 0.99726857, 0.997355179], # noqa: E501 + 6: [0.99999759, 0.999041857, 0.999030444, 0.999017417, 0.998929864, 0.998990419, 0.998976447, 0.998960564, 0.998947839, 0.998932861, 0.998917627, 0.998902469, 0.998886095, 0.998870496, 0.998854095, 0.998828293, 0.998820025, 0.998802423, 0.998781888, 0.998766091, 0.99874711, 0.998727536, 0.998708338, 0.998687388, 0.998667378, 0.998646422, 0.998615891, 0.99860258, 0.998579901, 0.99854771, 0.998532893, 0.998508311, 0.998482568, 0.998457817, 0.998430912, 0.998404321, 0.998376833, 0.998344152, 0.998319175, 0.998289252, 0.99824419, 0.998226929, 0.998194302, 0.998159782, 0.998126851, 0.99809069, 0.998054876, 0.998018038, 0.997972971, 0.99794, 0.997899419, 0.997790922, 0.997814459, 0.997769843, 0.997722364, 0.997676976, 0.99762747, 0.997577451, 0.997526075, 0.997465109, 0.997416953, 0.997359918, 0.997117356, 0.997239836, 0.997176357, 0.997107565, 0.997043806, 0.996972739, 0.996900412, 0.99682606, 0.996737263, 0.996667033, 0.996583297, 0.996275925, 0.996405909, 0.996311464, 0.996209838, 0.9961129, 0.996005321, 0.995895769, 0.995782527, 0.995644162, 0.995537739, 0.995408152, 0.995028936, 0.995130754, 0.994981722, 0.994819785, 0.994665055, 0.994488048, 0.994313551, 0.994129506, 0.993894127, 0.993726106, 0.993508999, 0.993127392, 0.993040799, 0.992785115, 0.992511123, 0.992237404, 0.991916293, 0.991618684, 0.991287331, 0.990086065, 0.990553695, 0.990151457, 0.989679133, 0.989273087, 0.988770326, 0.98825685, 0.987715402, 0.986923625, 0.986480536, 0.985795597, 0.98442128, 0.984264725, 0.983379194, 0.982440079, 0.981466219, 0.980197288, 0.979158275, 0.977866771, 0.974945959, 0.974878408, 0.9731116, 0.971163462, 0.969164454, 0.963283089, 0.964229193, 0.961311042, 0.957401028, 0.954502241, 0.94896279, 0.945695337, 0.940387818, 0.932357441, 0.927593288, 0.893122744, 0.910343958, 0.89783638, 0.886092144, 0.871937151, 0.847060107, 0.834710852, 0.797153005, 0.778551664, 0.673298191, 0.365174208, 0.631588655, 0.519661751, -0.146457275, -0.355474549, -0.146767833, 0.51655868, 0.622551055, 0.359633508, 0.660600777, 0.762843581, 0.780806629, 0.816626191, 0.82820531, 0.851856494, 0.864952188, 0.875558478, 0.886697306, 0.869124274, 0.901117601, 0.904946561, 0.912014015, 0.917053345, 0.921515932, 0.930146614, 0.937188178, 0.944781151, 0.950641047, 0.952111892, 0.960012742, 0.963581799, 0.966790375, 0.969554444, 0.970450561, 0.974028953, 0.975873634, 0.97737012, 0.979019003, 0.980312404, 0.981519897, 0.982633191, 0.982983247, 0.984522386, 0.985348968, 0.985914448, 0.98681157, 0.987444662, 0.988038284, 0.98861112, 0.989078689, 0.989605395, 0.990055641, 0.989633365, 0.990869928, 0.991235132, 0.991563163, 0.991911512, 0.992209743, 0.99250592, 0.99278162, 0.99288629, 0.993284511, 0.99351663, 0.993698444, 0.993946309, 0.994141828, 0.9943269, 0.994513589, 0.994677242, 0.99484743, 0.995004059, 0.994908755, 0.995295033, 0.995430663, 0.995542749, 0.995686322, 0.995804435, 0.995918549, 0.996030369, 0.996131269, 0.99623663, 0.996334555, 0.99620746, 0.996518287, 0.996604905, 0.996677882, 0.996769208, 0.996845967, 0.996920576, 0.996993788, 0.997059568, 0.997130291, 0.997195585, 0.997074511, 0.997319023, 0.997377605, 0.997427245, 0.997489595, 0.997542297, 0.997593585, 0.99764429, 0.997690821, 0.997739397, 0.997785052, 0.997762355, 0.997871954, 0.997913441, 0.997947283, 0.997993175, 0.998030805, 0.998067381, 0.998104266, 0.998137894, 0.998173084, 0.998206349, 0.998224221, 0.998269878, 0.998300367, 0.998325895, 0.998359093, 0.998387083, 0.99841416, 0.99844153, 0.998466727, 0.998492905, 0.998517903, 0.99853311, 0.998565699, 0.998588751, 0.998602432, 0.998633301, 0.998654592, 0.998674927, 0.998696188, 0.998715688, 0.998735556, 0.998754819, 0.998770888, 0.998791688, 0.998809545, 0.998818074, 0.998844105, 0.998860737, 0.998876563, 0.998893154, 0.998908523, 0.998923963, 0.998939139, 0.998952054, 0.998968128, 0.998982281, 0.998921988, 0.999009627, 0.99902282, 0.999034397], # noqa: E501 + 7: [0.987774087, 0.529314231, 0.528707525, 0.528099223, 0.52749, 0.526880923, 0.526269887, 0.525658066, 0.525046002, 0.524431517, 0.523816315, 0.52320017, 0.522581723, 0.521961959, 0.521340786, 0.520717033, 0.520091911, 0.51946463, 0.518834686, 0.518203469, 0.517568351, 0.516931116, 0.516291746, 0.515647869, 0.515002168, 0.51435304, 0.513698939, 0.513042937, 0.512382406, 0.511717145, 0.511048338, 0.51037447, 0.509695719, 0.509012361, 0.508322957, 0.507628854, 0.506928953, 0.506221056, 0.505508957, 0.504789348, 0.504061181, 0.503327887, 0.502585511, 0.501834732, 0.501076135, 0.500307537, 0.499530425, 0.498743076, 0.497944258, 0.497136773, 0.496316447, 0.495483503, 0.494640622, 0.493782978, 0.492910564, 0.492026608, 0.491125422, 0.490208945, 0.489277244, 0.488325928, 0.487359505, 0.486372334, 0.485364241, 0.484338761, 0.48328848, 0.482215494, 0.481121921, 0.480000005, 0.478852439, 0.47767933, 0.476472971, 0.475241885, 0.473977455, 0.472677425, 0.471346045, 0.469975486, 0.468565865, 0.467120231, 0.46562749, 0.464093644, 0.462514381, 0.460881281, 0.459202895, 0.457468884, 0.455673762, 0.453828406, 0.451915108, 0.449935352, 0.447894261, 0.44577321, 0.443582389, 0.441314127, 0.438953865, 0.436519237, 0.433988665, 0.431355965, 0.42863893, 0.425806624, 0.422865665, 0.419821389, 0.416643345, 0.413354664, 0.409936372, 0.406367403, 0.402684276, 0.398842397, 0.394840203, 0.390707025, 0.386386747, 0.381905899, 0.377258888, 0.372395614, 0.367377225, 0.362141049, 0.356666513, 0.35102335, 0.345109751, 0.338966063, 0.332610518, 0.325949917, 0.31910676, 0.311989629, 0.30455607, 0.296944413, 0.288946141, 0.280643173, 0.272044409, 0.262969225, 0.253697736, 0.24398421, 0.233863594, 0.22356677, 0.212725911, 0.201808121, 0.190527513, 0.178882256, 0.167363699, 0.155265954, 0.143632374, 0.131618158, 0.11982825, 0.108285266, 0.096495022, 0.085752086, 0.073973473, 0.064186092, 0.050859496, 0.040782131, 0.037820941, 0.027146226, 0.013623523, 0.010801135, 0.013548243, 0.027003593, 0.037679682, 0.040700019, 0.050333145, 0.063533648, 0.073076061, 0.08463352, 0.095106133, 0.106645353, 0.117907586, 0.129458931, 0.141221159, 0.15267075, 0.164604792, 0.176035785, 0.187673988, 0.199024373, 0.210119941, 0.221220936, 0.231879942, 0.242458801, 0.25270386, 0.262578706, 0.272299766, 0.28156814, 0.290546137, 0.299184251, 0.307372713, 0.315288823, 0.322782269, 0.329961547, 0.336897473, 0.343554839, 0.350137666, 0.35676753, 0.363349536, 0.369771104, 0.375848626, 0.381606436, 0.387150689, 0.392374791, 0.397358659, 0.402111878, 0.406622784, 0.410956857, 0.415090236, 0.419031078, 0.422831899, 0.426458049, 0.429937087, 0.433290877, 0.436498158, 0.439594643, 0.442574779, 0.445434068, 0.448206778, 0.450874065, 0.453444292, 0.455940115, 0.45834338, 0.460670576, 0.462928908, 0.46510822, 0.467228731, 0.469284098, 0.471273253, 0.473214855, 0.47509685, 0.476925448, 0.478710958, 0.480444228, 0.482134423, 0.483785267, 0.485390513, 0.486961993, 0.48849573, 0.489991778, 0.49145819, 0.492890058, 0.494289999, 0.495666519, 0.497010559, 0.498330485, 0.499625436, 0.500892646, 0.502140431, 0.503366088, 0.504567315, 0.50575195, 0.506916024, 0.50805964, 0.509188588, 0.510297976, 0.511392653, 0.512471988, 0.513534006, 0.514584935, 0.515620777, 0.516641698, 0.517652792, 0.518651065, 0.519636179, 0.520612487, 0.521577023, 0.522530983, 0.523476049, 0.524410731, 0.525337235, 0.526254775, 0.527162667, 0.528065233, 0.528958432, 0.529843551, 0.530724115, 0.5315958, 0.532461128, 0.533321098, 0.534174359, 0.535022334, 0.535864627, 0.536701175, 0.537533433, 0.538360083, 0.539181397, 0.540000225, 0.540813022, 0.541621604, 0.542427654, 0.543228107, 0.544025671, 0.544820299, 0.545609923, 0.546397497, 0.547181397, 0.547962306, 0.548740869, 0.549515808, 0.550288521, 0.551058858, 0.551825829, 0.552590961, 0.553354459, 0.554114342, 0.554873093, 0.555630086, 0.556383952, 0.557137027, 0.557888172, 0.558636565], # noqa: E501 + 8: [0.999820569, 0.87527645, 0.859938291, 0.822520225, 0.75827725, 0.682423934, 0.587279692, 0.497985744, 0.404121707, 0.333886271, 0.274358848, 0.249599237, 0.244799439, 0.274941092, 0.324494648, 0.398250875, 0.481554696, 0.57152178, 0.656069831, 0.729198968, 0.782742224, 0.81190881, 0.813176736, 0.785418611, 0.729627584, 0.64832101, 0.546192759, 0.427470403, 0.299333029, 0.165115777, 0.033253273, -0.095820358, -0.213486696, -0.322811082, -0.415837872, -0.498276213, -0.563469954, -0.617948473, -0.65685748, -0.685494392, -0.701187973, -0.706469123, -0.701534463, -0.686190397, -0.657920809, -0.619370576, -0.565260104, -0.500422299, -0.418325893, -0.325602353, -0.216536731, -0.099056354, 0.029921487, 0.161792009, 0.296136948, 0.424528353, 0.543622746, 0.646234634, 0.728104224, 0.784510655, 0.812887787, 0.812196741, 0.783519447, 0.730319415, 0.657374198, 0.572790198, 0.482607961, 0.398858475, 0.324532419, 0.27424232, 0.243349988, 0.247366631, 0.27144783, 0.330415665, 0.400290775, 0.49402276, 0.583397007, 0.678872016, 0.75520628, 0.820108624, 0.858200265, 0.874270824], # noqa: E501 +} +# fmt: on + +# fmt: off +PULSES = { + 5: torch.tensor([0.49835002422332764, 1.510050654411316, 2.5419647693634033, 3.5943539142608643, 4.667484760284424, 5.761628150939941, 6.877054691314697, 8.014037132263184, 9.172842979431152, 10.353734970092773, 11.556971549987793, 12.782797813415527, 14.031449317932129, 15.303146362304688, 16.598094940185547, 17.916479110717773, 19.258464813232422, 20.624191284179688, 22.013771057128906, 23.42729377746582, 24.86481285095215, 26.326353073120117, 27.81189727783203, 29.321399688720703, 30.854766845703125, 32.411869049072266, 33.99253463745117, 35.59654235839844, 37.22362518310547, 38.87346267700195, 40.54569625854492, 42.239906311035156, 43.955623626708984, 45.69232177734375, 47.44942092895508, 49.226287841796875, 51.02223205566406, 52.83649826049805, 54.66828918457031, 56.516727447509766, 58.38090133666992, 60.25982666015625, 62.152462005615234, 64.0577163696289, 65.97443389892578, 67.90141296386719, 69.8373794555664, 71.7810287475586, 73.7309799194336, 75.6858139038086, 77.64407348632812, 79.60423278808594, 81.5647201538086, 83.52394104003906, 85.4802474975586, 87.43193817138672, 89.37730407714844, 91.31458282470703, 93.24198150634766, 95.1576919555664, 97.05986022949219, 98.94664001464844, 100.81614685058594, 102.66647338867188, 104.49571990966797, 106.30196380615234, 108.08329772949219, 109.8377914428711, 111.56352233886719, 113.25859069824219, 114.92108917236328, 116.54913330078125, 118.14085388183594, 119.69440460205078, 121.20796203613281, 122.67974090576172, 124.10797119140625, 125.4909439086914, 126.82697296142578, 128.1144256591797, 129.35171508789062, 130.5372772216797, 131.669677734375, 132.74745178222656, 133.7692413330078, 134.73374938964844, 135.63975524902344, 136.48606872558594, 137.27162170410156, 137.99537658691406, 138.65638732910156, 139.25381469726562, 139.78684997558594, 140.25482177734375, 140.65707397460938, 140.9931182861328, 141.26248168945312, 141.46481323242188, 141.599853515625, 141.66741943359375, 141.66741943359375, 141.599853515625, 141.46481323242188, 141.26248168945312, 140.9931182861328, 140.65707397460938, 140.25482177734375, 139.78684997558594, 139.25381469726562, 138.65638732910156, 137.99537658691406, 137.27162170410156, 136.48606872558594, 135.63975524902344, 134.73374938964844, 133.7692413330078, 132.74745178222656, 131.669677734375, 130.5372772216797, 129.35171508789062, 128.1144256591797, 126.82697296142578, 125.4909439086914, 124.10797119140625, 122.67974090576172, 121.20796203613281, 119.69440460205078, 118.14085388183594, 116.54913330078125, 114.92108917236328, 113.25859069824219, 111.56352233886719, 109.8377914428711, 108.08329772949219, 106.30196380615234, 104.49571990966797, 102.66647338867188, 100.81614685058594, 98.94664001464844, 97.05986022949219, 95.1576919555664, 93.24198150634766, 91.31458282470703, 89.37730407714844, 87.43193817138672, 85.4802474975586, 83.52394104003906, 81.5647201538086, 79.60423278808594, 77.64407348632812, 75.6858139038086, 73.7309799194336, 71.7810287475586, 69.8373794555664, 67.90141296386719, 65.97443389892578, 64.0577163696289, 62.152462005615234, 60.25982666015625, 58.38090133666992, 56.516727447509766, 54.66828918457031, 52.83649826049805, 51.02223205566406, 49.226287841796875, 47.44942092895508, 45.69232177734375, 43.955623626708984, 42.239906311035156, 40.54569625854492, 38.87346267700195, 37.22362518310547, 35.59654235839844, 33.99253463745117, 32.411869049072266, 30.854766845703125, 29.321399688720703, 27.81189727783203, 26.326353073120117, 24.86481285095215, 23.42729377746582, 22.013771057128906, 20.624191284179688, 19.258464813232422, 17.916479110717773, 16.598094940185547, 15.303146362304688, 14.031449317932129, 12.782797813415527, 11.556971549987793, 10.353734970092773, 9.172842979431152, 8.014037132263184, 6.877054691314697, 5.761628150939941, 4.667484760284424, 3.5943539142608643, 2.5419647693634033, 1.510050654411316, 0.49835002422332764], dtype=torch.float32), # noqa: E501 + 6: torch.tensor([0.49835002422332764, 1.510050654411316, 2.5419647693634033, 3.5943539142608643, 4.667484760284424, 5.761628150939941, 6.877054691314697, 8.014037132263184, 9.172842979431152, 10.353734970092773, 11.556971549987793, 12.782797813415527, 14.031449317932129, 15.303146362304688, 16.598094940185547, 17.916479110717773, 19.258464813232422, 20.624191284179688, 22.013771057128906, 23.42729377746582, 24.86481285095215, 26.326353073120117, 27.81189727783203, 29.321399688720703, 30.854766845703125, 32.411869049072266, 33.99253463745117, 35.59654235839844, 37.22362518310547, 38.87346267700195, 40.54569625854492, 42.239906311035156, 43.955623626708984, 45.69232177734375, 47.44942092895508, 49.226287841796875, 51.02223205566406, 52.83649826049805, 54.66828918457031, 56.516727447509766, 58.38090133666992, 60.25982666015625, 62.152462005615234, 64.0577163696289, 65.97443389892578, 67.90141296386719, 69.8373794555664, 71.7810287475586, 73.7309799194336, 75.6858139038086, 77.64407348632812, 79.60423278808594, 81.5647201538086, 83.52394104003906, 85.4802474975586, 87.43193817138672, 89.37730407714844, 91.31458282470703, 93.24198150634766, 95.1576919555664, 97.05986022949219, 98.94664001464844, 100.81614685058594, 102.66647338867188, 104.49571990966797, 106.30196380615234, 108.08329772949219, 109.8377914428711, 111.56352233886719, 113.25859069824219, 114.92108917236328, 116.54913330078125, 118.14085388183594, 119.69440460205078, 121.20796203613281, 122.67974090576172, 124.10797119140625, 125.4909439086914, 126.82697296142578, 128.1144256591797, 129.35171508789062, 130.5372772216797, 131.669677734375, 132.74745178222656, 133.7692413330078, 134.73374938964844, 135.63975524902344, 136.48606872558594, 137.27162170410156, 137.99537658691406, 138.65638732910156, 139.25381469726562, 139.78684997558594, 140.25482177734375, 140.65707397460938, 140.9931182861328, 141.26248168945312, 141.46481323242188, 141.599853515625, 141.66741943359375, 141.66741943359375, 141.599853515625, 141.46481323242188, 141.26248168945312, 140.9931182861328, 140.65707397460938, 140.25482177734375, 139.78684997558594, 139.25381469726562, 138.65638732910156, 137.99537658691406, 137.27162170410156, 136.48606872558594, 135.63975524902344, 134.73374938964844, 133.7692413330078, 132.74745178222656, 131.669677734375, 130.5372772216797, 129.35171508789062, 128.1144256591797, 126.82697296142578, 125.4909439086914, 124.10797119140625, 122.67974090576172, 121.20796203613281, 119.69440460205078, 118.14085388183594, 116.54913330078125, 114.92108917236328, 113.25859069824219, 111.56352233886719, 109.8377914428711, 108.08329772949219, 106.30196380615234, 104.49571990966797, 102.66647338867188, 100.81614685058594, 98.94664001464844, 97.05986022949219, 95.1576919555664, 93.24198150634766, 91.31458282470703, 89.37730407714844, 87.43193817138672, 85.4802474975586, 83.52394104003906, 81.5647201538086, 79.60423278808594, 77.64407348632812, 75.6858139038086, 73.7309799194336, 71.7810287475586, 69.8373794555664, 67.90141296386719, 65.97443389892578, 64.0577163696289, 62.152462005615234, 60.25982666015625, 58.38090133666992, 56.516727447509766, 54.66828918457031, 52.83649826049805, 51.02223205566406, 49.226287841796875, 47.44942092895508, 45.69232177734375, 43.955623626708984, 42.239906311035156, 40.54569625854492, 38.87346267700195, 37.22362518310547, 35.59654235839844, 33.99253463745117, 32.411869049072266, 30.854766845703125, 29.321399688720703, 27.81189727783203, 26.326353073120117, 24.86481285095215, 23.42729377746582, 22.013771057128906, 20.624191284179688, 19.258464813232422, 17.916479110717773, 16.598094940185547, 15.303146362304688, 14.031449317932129, 12.782797813415527, 11.556971549987793, 10.353734970092773, 9.172842979431152, 8.014037132263184, 6.877054691314697, 5.761628150939941, 4.667484760284424, 3.5943539142608643, 2.5419647693634033, 1.510050654411316, 0.49835002422332764], dtype=torch.float32) # noqa: E501 +} +# fmt: on + + +class BMModel(SignalModel[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]): + """Base signal model using Bloch-McConnel simulation.""" + + static_off_resonance_offsets: torch.Tensor | None + + def __init__(self) -> None: + super().__init__() + self.sequence: BMCSequence + self.register_buffer('static_off_resonance_offsets', None) + self.b0 = 0.0 + + def forward( + self, + equilibrium_magnetization: torch.Tensor, + t1: torch.Tensor, + t2: torch.Tensor, + exchange_rate: torch.Tensor, + chemical_shift: torch.Tensor | None, + ) -> tuple[torch.Tensor]: + """Simulate the sequence.""" + + static_off_resonance = None + if self.static_off_resonance_offsets is not None: + static_off_resonance = torch.as_tensor( + -2 * torch.pi * self.static_off_resonance_offsets * (magnetic_field_to_lamor_frequency(self.b0) / 1e6) + ) + parameters = Parameters( + equilibrium_magnetization, t1, t2, exchange_rate, chemical_shift, static_off_resonance=static_off_resonance + ) + + _, signals = self.sequence(parameters) + if len(signals) == 1: + return (signals[0],) + return (torch.stack(signals, dim=0),) + + +class BMsimSingleBlock(BMModel): + """Bloch-McConnell sequence with one constant saturation block per offset.""" + + def __init__( + self, + offsets: torch.Tensor | Sequence[float], + pulse_duration: float, + b1: float, + post_pulse_delay: float = 6.5e-3, + b0: float = 3.0, + readout_pool: int = 0, + ) -> None: + super().__init__() + offsets = torch.as_tensor(offsets) + rf_amplitude_hz = torch.as_tensor(b1 * GYROMAGNETIC_RATIO_PROTON, dtype=offsets.dtype) + + sequence = BMCSequence() + sequence.append(ConstantRFBlock(duration=pulse_duration, rf_amplitude=rf_amplitude_hz, rf_frequency=0.0)) + sequence.append(SpoilBlock(post_pulse_delay)) + sequence.append(LongitudinalReadoutBlock(pool_index=readout_pool)) + + self.static_off_resonance_offsets = offsets + self.b0 = b0 + self.sequence = sequence + + +class BMsimSingleGaussianPulse(BMModel): + """Bloch-McConnell sequence with one shaped pulse per offset.""" + + def __init__( + self, + offsets: torch.Tensor | Sequence[float], + pulse_envelope: torch.Tensor | Sequence[float], + dt: float, + post_pulse_delay: float = 6.5e-3, + b0: float = 3.0, + readout_pool: int = 0, + ) -> None: + super().__init__() + offsets = torch.as_tensor(offsets) + pulse_envelope = torch.as_tensor(pulse_envelope) + + sequence = BMCSequence() + sequence.append(PiecewiseRFBlock(rf_amplitude=pulse_envelope, dt=dt)) + sequence.append(SpoilBlock(post_pulse_delay)) + sequence.append(LongitudinalReadoutBlock(pool_index=readout_pool)) + + self.static_off_resonance_offsets = offsets + self.sequence = sequence + self.b0 = b0 + + +class BMsimGaussianPulseTrain(BMModel): + """Bloch-McConnell sequence with a train of shaped pulses per offset.""" + + def __init__( + self, + offsets: torch.Tensor | Sequence[float], + pulse_envelope: torch.Tensor | Sequence[float], + dt: float, + n_pulses: int, + interpulse_delay: float, + post_pulse_delay: float = 6.5e-3, + b0: float = 3.0, + readout_pool: int = 0, + ) -> None: + super().__init__() + offsets = torch.as_tensor(offsets) + pulse_envelope = torch.as_tensor(pulse_envelope) + larmor_frequency_hz = magnetic_field_to_lamor_frequency(b0) + offsets_hz = offsets * (larmor_frequency_hz / 1e6) + pulse_duration = float(dt) * int(pulse_envelope.numel()) + + sequence = BMCSequence() + sequence.append(ResetBlock()) + accumulated_phase = torch.zeros_like(offsets_hz) + for pulse_index in range(n_pulses): + sequence.append( + PiecewiseRFBlock( + rf_amplitude=pulse_envelope, + rf_phase=accumulated_phase[None, :], + rf_frequency=offsets_hz[None, :], + dt=dt, + ) + ) + accumulated_phase = (accumulated_phase + 2 * torch.pi * offsets_hz * pulse_duration) % (2 * torch.pi) + if pulse_index < n_pulses - 1: + sequence.append(DelayBlock(interpulse_delay)) + sequence.append(SpoilBlock(post_pulse_delay)) + sequence.append(LongitudinalReadoutBlock(pool_index=readout_pool)) + + self.sequence = sequence + + +class BMsimTwoPulseWASABI(BMModel): + """Bloch-McConnell sequence with two block pulses per offset.""" + + def __init__( + self, + offsets: torch.Tensor | Sequence[float], + pulse_duration: float, + b1: float, + interpulse_delay: float, + post_pulse_delay: float = 6.5e-3, + b0: float = 3.0, + readout_pool: int = 0, + ) -> None: + super().__init__() + offsets = torch.as_tensor(offsets) + rf_amplitude_hz = torch.as_tensor(b1 * GYROMAGNETIC_RATIO_PROTON, dtype=offsets.dtype) + larmor_frequency_hz = magnetic_field_to_lamor_frequency(b0) + offsets_hz = offsets * (larmor_frequency_hz / 1e6) + + sequence = BMCSequence() + sequence.append(ResetBlock()) + for pulse_index in range(2): + sequence.append( + ConstantRFBlock( + duration=pulse_duration, + rf_amplitude=rf_amplitude_hz, + rf_frequency=offsets_hz, + ) + ) + if pulse_index == 0: + sequence.append(DelayBlock(interpulse_delay)) + sequence.append(SpoilBlock(post_pulse_delay)) + sequence.append(LongitudinalReadoutBlock(pool_index=readout_pool)) + + self.offsets = offsets + self.sequence = sequence + + +class IsochromatGradientReadoutModel( + SignalModel[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None] +): + """Test-only model for analytic isochromat gradient comparisons.""" + + def __init__( + self, + positions: SpatialDimension[torch.Tensor], + gradient_blocks: Sequence[tuple[float, float, float, float]], + initial_phase: float = 0.0, + readout_pool: int = 0, + ) -> None: + super().__init__() + self.positions = positions.apply(torch.as_tensor) + self.initial_phase = float(initial_phase) + self.readout_pool = readout_pool + self.sequence = BMCSequence( + [ + *[ + GradientBlock( + duration=duration, + positions=self.positions, + gradient_z=gradient_z, + gradient_y=gradient_y, + gradient_x=gradient_x, + ) + for duration, gradient_z, gradient_y, gradient_x in gradient_blocks + ], + AcquisitionBlock(pool_index=readout_pool), + ] + ) + + def forward( + self, + equilibrium_magnetization: torch.Tensor, + t1: torch.Tensor, + t2: torch.Tensor, + exchange_rate: torch.Tensor, + chemical_shift: torch.Tensor | None, + ) -> tuple[torch.Tensor]: + parameters = Parameters(equilibrium_magnetization, t1, t2, exchange_rate, chemical_shift) + state = torch.zeros( + self.positions.z.numel(), + parameters.n_pools, + 3, + device=equilibrium_magnetization.device, + dtype=equilibrium_magnetization.dtype, + ) + state[..., self.readout_pool, 0] = torch.cos( + torch.tensor(self.initial_phase, dtype=state.dtype, device=state.device) + ) + state[..., self.readout_pool, 1] = torch.sin( + torch.tensor(self.initial_phase, dtype=state.dtype, device=state.device) + ) + _, (signal,) = self.sequence(parameters, state) + return (signal,) + + +def test_gradient_block_zero_gradient_matches_delay() -> None: + """Test that a gradient block with zero gradient matches a delay block.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0]), + t1=torch.tensor([1.2]), + t2=torch.tensor([0.08]), + exchange_rate=torch.tensor([[0.0]]), + chemical_shift=torch.tensor([0.0]), + ) + state = initial_state(parameters).expand(5, -1, -1).clone() + positions = SpatialDimension( + z=torch.linspace(-2e-3, 2e-3, 5), + y=torch.zeros(5), + x=torch.zeros(5), + ) + gradient_block = GradientBlock( + duration=4e-3, + gradient_z=torch.tensor(0.0), + gradient_y=torch.tensor(0.0), + gradient_x=torch.tensor(0.0), + positions=positions, + ) + delay_block = DelayBlock(4e-3) + + gradient_state, _ = gradient_block(parameters, state) + delay_state, _ = delay_block(parameters, state) + + torch.testing.assert_close(gradient_state, delay_state) + + +def test_isochromat_gradient_model_matches_analytic_signal() -> None: + """Compare the isochromat gradient model to the analytic signal.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0]), + t1=torch.tensor([1e12]), + t2=torch.tensor([1e12]), + exchange_rate=torch.tensor([[0.0]]), + chemical_shift=torch.tensor([0.0]), + ) + positions = SpatialDimension( + z=torch.tensor([-2e-3, 0.0, 2e-3]), + y=torch.tensor([1e-3, -1e-3, 2e-3]), + x=torch.tensor([0.5e-3, -0.5e-3, 1e-3]), + ) + duration = 3e-3 + gradient_z = torch.tensor(0.012) + gradient_y = torch.tensor(-0.007) + gradient_x = torch.tensor(0.004) + model = IsochromatGradientReadoutModel( + positions=positions, + gradient_blocks=[(duration, float(gradient_z), float(gradient_y), float(gradient_x))], + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + + phase = duration * gradient_to_extra_off_resonance(gradient_z, gradient_y, gradient_x, positions) + expected_signal = torch.exp(1j * phase).mean() + torch.testing.assert_close(signal, expected_signal, atol=1e-5, rtol=1e-5) + + +def test_piecewise_rf_block_extra_off_resonance_shape() -> None: + """Test that the piecewise RF block with extra off-resonance batches correctly.""" + parameters = Parameters( + equilibrium_magnetization=torch.ones(5, 1), + t1=torch.full((5, 1), 1.2), + t2=torch.full((5, 1), 0.08), + exchange_rate=torch.zeros(5, 1, 1), + chemical_shift=torch.zeros(5, 1), + ) + state = initial_state(parameters).clone() + block = PiecewiseRFBlock( + rf_amplitude=torch.tensor([1.0, 0.5]), + dt=2.5e-4, + extra_off_resonance=torch.zeros(2, 5), + ) + out, _ = block(parameters, state) + assert out.shape[0] == 5 + assert out.shape[-3:] == (1, 1, 3) + + +def test_slice_selective_rf_flat_profile_matches_piecewise_rf_block() -> None: + """Test that the slice selective RF block with a flat profile matches the piecewise RF block.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0]), + t1=torch.tensor([1.2]), + t2=torch.tensor([0.08]), + exchange_rate=torch.tensor([[0.0]]), + chemical_shift=torch.tensor([0.0]), + ) + state = initial_state(parameters).expand(5, -1, -1).clone() + positions = SpatialDimension( + z=torch.linspace(-2e-3, 2e-3, 5), + y=torch.zeros(5), + x=torch.zeros(5), + ) + amp = torch.tensor([0.3, 0.8, 1.0, 0.8, 0.3]) + profile = SliceSmoothedRectangular(fwhm_rect=1.0, fwhm_gauss=0.0) + + rf_state, _ = PiecewiseRFBlock(rf_amplitude=amp, dt=2.5e-4)(parameters, state) + profile_state, _ = SliceSelectiveRFBlock( + rf_amplitude=amp, + slice_profile=profile, + positions=positions, + dt=2.5e-4, + )(parameters, state) + + torch.testing.assert_close(profile_state, rf_state) + + +@pytest.mark.parametrize('slice_fwhm', [4e-3, 7e-3]) +def test_slice_profile_approx_rf(slice_fwhm: float) -> None: + """Test that the slice profile approximation matches the explicit RF block.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0]), + t1=torch.tensor([1.2]), + t2=torch.tensor([0.08]), + exchange_rate=torch.tensor([[0.0]]), + chemical_shift=torch.tensor([0.0]), + ) + n_iso = 101 + positions = SpatialDimension( + z=torch.linspace(-1.5 * slice_fwhm, 1.5 * slice_fwhm, n_iso), + y=torch.zeros(n_iso), + x=torch.zeros(n_iso), + ) + state = initial_state(parameters).expand(n_iso, -1, -1).clone() + + flip_angle = torch.deg2rad(torch.tensor(5.0)) + duration = 2e-3 + dt = 10e-6 + rf_pulse = GaussianRFPulse(fwhm_fraction=0.35) + waveform = rf_pulse(flip_angle=flip_angle, duration=duration, dt=dt) + + gradient_z = ( + 4 + * torch.log(torch.tensor(2.0)) + / (torch.pi * rf_pulse.fwhm_fraction * duration * GYROMAGNETIC_RATIO_PROTON * slice_fwhm) + ) + extra_off_resonance = gradient_to_extra_off_resonance(gradient_z, None, None, positions) + + explicit_state, _ = PiecewiseRFBlock( + rf_amplitude=waveform[..., None], + dt=dt, + extra_off_resonance=extra_off_resonance[None, :], + )(parameters, state.unsqueeze(-3)) + explicit_state = explicit_state.squeeze(-3) + + approximate_state, _ = SliceSelectiveRFBlock( + rf_amplitude=waveform, + slice_profile=SliceGaussian(fwhm=slice_fwhm), + positions=positions, + dt=dt, + )(parameters, state) + + explicit_profile = torch.linalg.vector_norm(explicit_state[..., 0, :2], dim=-1) + approximate_profile = torch.linalg.vector_norm(approximate_state[..., 0, :2], dim=-1) + explicit_profile = explicit_profile / explicit_profile.max() + approximate_profile = approximate_profile / approximate_profile.max() + + torch.testing.assert_close(approximate_profile, explicit_profile, atol=7e-2, rtol=7e-2) + + +@pytest.mark.parametrize('pulse_cls', [GaussianRFPulse, SincRFPulse]) +def test_slice_rf_pulse_templates_match_flip_angle_and_zero_phase(pulse_cls: type[SliceRFPulseBase]) -> None: + """Test that the slice RF pulse templates match the flip angle and have zero intrinsic phase.""" + flip_angle = torch.tensor(torch.pi / 3) + dt = torch.tensor(10e-6) + duration = torch.tensor(2e-3) + pulse = pulse_cls() + rf, phase = pulse.rf_and_phase(flip_angle=flip_angle, duration=duration, dt=dt) + assert rf.shape == phase.shape + torch.testing.assert_close(phase, torch.zeros_like(phase)) # no intrinsic phase + torch.testing.assert_close(rf, pulse(flip_angle=flip_angle, duration=duration, dt=dt)) + + achieved_flip_angle = GYROMAGNETIC_RATIO_PROTON * dt * (rf * torch.exp(1j * phase)).sum() + expected_flip_angle = torch.complex(flip_angle, torch.zeros_like(flip_angle)) + torch.testing.assert_close(achieved_flip_angle, expected_flip_angle, atol=1e-6, rtol=1e-6) + + +def test_gradient_to_extra_off_resonance() -> None: + positions = SpatialDimension( + z=torch.tensor([-1e-3, 1e-3, 2e-3]), + y=torch.zeros(3), + x=torch.zeros(3), + ) + gradient_z = torch.tensor([0.01, 0.02]) + expected = 2 * torch.pi * GYROMAGNETIC_RATIO_PROTON * gradient_z[:, None] * positions.z[None, :] + actual = gradient_to_extra_off_resonance(gradient_z, None, None, positions) + torch.testing.assert_close(actual, expected) + + +def test_bmsim_case_1_basic() -> None: + """Basic smoke test for BMsim case 1.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 5.0e-4]), + t1=torch.tensor([3.0, 1.05]), + t2=torch.tensor([2.0, 0.1]), + exchange_rate=torch.tensor([[0.0, 50.0], [5.0e-4 * 50.0, 0.0]]), + chemical_shift=torch.tensor([0.0, 1.9 * 42.5764 * 3.0]), + ) + offsets = torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))) + model = BMsimSingleBlock(offsets=offsets, pulse_duration=15.0, b1=2.0e-6) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + assert signal.shape == (302,) + assert signal.isfinite().all() + signal_norm = signal / signal[0] + torch.testing.assert_close(signal_norm[0], torch.tensor(1.0)) + torch.testing.assert_close(model.sequence.duration, torch.tensor(15.0 + 6.5e-3), atol=1e-6, rtol=0) + + +def test_bmsim_case_2_basic() -> None: + """Basic smoke test for BMsim case 2.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 5.0e-4]), + t1=torch.tensor([3.0, 1.05]), + t2=torch.tensor([2.0, 0.1]), + exchange_rate=torch.tensor([[0.0, 50.0], [5.0e-4 * 50.0, 0.0]]), + chemical_shift=torch.tensor([0.0, 1.9 * 42.5764 * 3.0]), + ) + offsets = torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))) + model = BMsimSingleBlock(offsets=offsets, pulse_duration=2.0, b1=2.0e-6) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + assert signal.shape == (302,) + assert signal.isfinite().all() + signal_norm = signal / signal[0] + torch.testing.assert_close(signal_norm[0], torch.tensor(1.0)) + torch.testing.assert_close(model.sequence.duration, torch.tensor(2.0 + 6.5e-3), atol=1e-6, rtol=0) + + +def test_bmsim_case_3_basic() -> None: + """Basic smoke test for BMsim case 3.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 0.1351, 0.0009009, 0.0009009, 0.0045]), + t1=torch.tensor([1.0, 1.0, 1.0, 1.0, 1.3]), + t2=torch.tensor([0.040, 4.0e-5, 0.1, 0.1, 0.005]), + exchange_rate=torch.tensor( + [ + [0.0, 30.0, 50.0, 1000.0, 20.0], + [0.1351 * 30.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 50.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 1000.0, 0.0, 0.0, 0.0, 0.0], + [0.0045 * 20.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + chemical_shift=torch.tensor( + [0.0, -3.0 * 42.5764 * 3.0, 3.5 * 42.5764 * 3.0, 2.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0] + ), + ) + offsets = torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))) + model = BMsimSingleBlock(offsets=offsets, pulse_duration=2.0, b1=2.0e-6) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + assert signal.shape == (302,) + assert signal.isfinite().all() + signal_norm = signal / signal[0] + torch.testing.assert_close(signal_norm[0], torch.tensor(1.0)) + + +def test_bmsim_case_4_basic() -> None: + """Basic smoke test for BMsim case 4.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 0.1351, 0.0009009, 0.0009009, 0.0045]), + t1=torch.tensor([1.0, 1.0, 1.0, 1.0, 1.3]), + t2=torch.tensor([0.040, 4.0e-5, 0.1, 0.1, 0.005]), + exchange_rate=torch.tensor( + [ + [0.0, 30.0, 50.0, 1000.0, 20.0], + [0.1351 * 30.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 50.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 1000.0, 0.0, 0.0, 0.0, 0.0], + [0.0045 * 20.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + chemical_shift=torch.tensor( + [0.0, -3.0 * 42.5764 * 3.0, 3.5 * 42.5764 * 3.0, 2.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0] + ), + ) + offsets = torch.cat((torch.tensor([-300.0]), torch.linspace(-2.0, 2.0, 81))) + model = BMsimSingleBlock(offsets=offsets, pulse_duration=5e-3, b1=3.7e-6) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + assert signal.shape == (82,) + assert signal.isfinite().all() + signal_norm = signal / signal[0] + torch.testing.assert_close(signal_norm[0], torch.tensor(1.0)) + torch.testing.assert_close(model.sequence.duration, torch.tensor(5e-3 + 6.5e-3), atol=1e-6, rtol=0) + + +def test_bmsim_case_5_basic() -> None: + """Basic smoke test for BMsim case 5.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 5.0e-4]), + t1=torch.tensor([3.0, 1.05]), + t2=torch.tensor([2.0, 0.1]), + exchange_rate=torch.tensor([[0.0, 50.0], [5.0e-4 * 50.0, 0.0]]), + chemical_shift=torch.tensor([0.0, 1.9 * 42.5764 * 3.0]), + ) + offsets = torch.cat((torch.tensor([-300.0]), torch.linspace(-2.0, 2.0, 201))) + model = BMsimSingleGaussianPulse(offsets=offsets, pulse_envelope=PULSES[5], dt=250e-6) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + assert signal.shape == (202,) + assert signal.isfinite().all() + torch.testing.assert_close((signal / signal[0])[0], torch.tensor(1.0)) + torch.testing.assert_close(model.sequence.duration, torch.tensor(50e-3 + 6.5e-3), atol=1e-6, rtol=0) + + +def test_bmsim_case_6_basic() -> None: + """Basic smoke test for BMsim case 6.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 5.0e-4]), + t1=torch.tensor([3.0, 1.05]), + t2=torch.tensor([2.0, 0.1]), + exchange_rate=torch.tensor([[0.0, 50.0], [5.0e-4 * 50.0, 0.0]]), + chemical_shift=torch.tensor([0.0, 1.9 * 42.5764 * 3.0]), + ) + offsets = torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))) + model = BMsimGaussianPulseTrain( + offsets=offsets, + pulse_envelope=PULSES[6], + dt=250e-6, + n_pulses=36, + interpulse_delay=5e-3, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + assert signal.shape == (302,) + assert signal.isfinite().all() + torch.testing.assert_close((signal / signal[0])[0], torch.tensor(1.0)) + torch.testing.assert_close(model.sequence.duration, torch.tensor(1.975 + 6.5e-3), atol=1e-6, rtol=0) + + +def test_bmsim_case_7_basic() -> None: + """Basic smoke test for BMsim case 7.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 0.0009009, 0.0009009, 0.0045, 0.1351]), + t1=torch.tensor([1.0, 1.0, 1.0, 1.3, 1.0]), + t2=torch.tensor([0.040, 0.1, 0.1, 0.005, 4.0e-5]), + exchange_rate=torch.tensor( + [ + [0.0, 50.0, 1000.0, 20.0, 30.0], + [0.0009009 * 50.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 1000.0, 0.0, 0.0, 0.0, 0.0], + [0.0045 * 20.0, 0.0, 0.0, 0.0, 0.0], + [0.1351 * 30.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + chemical_shift=torch.tensor( + [0.0, 3.5 * 42.5764 * 3.0, 2.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0] + ), + ) + + offsets = torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))) + model = BMsimGaussianPulseTrain( + offsets=offsets, + pulse_envelope=PULSES[6], + dt=250e-6, + n_pulses=36, + interpulse_delay=5e-3, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + assert signal.shape == (302,) + assert signal.isfinite().all() + torch.testing.assert_close((signal / signal[0])[0], torch.tensor(1.0)) + torch.testing.assert_close(model.sequence.duration, torch.tensor(1.975 + 6.5e-3), atol=1e-6, rtol=0) + + +def test_bmsim_case_8_basic() -> None: + """Basic smoke test for BMsim case 8.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 0.0009009, 0.0009009, 0.0045, 0.1351]), + t1=torch.tensor([1.0, 1.0, 1.0, 1.3, 1.0]), + t2=torch.tensor([0.040, 0.1, 0.1, 0.005, 4.0e-5]), + exchange_rate=torch.tensor( + [ + [0.0, 50.0, 1000.0, 20.0, 30.0], + [0.0009009 * 50.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 1000.0, 0.0, 0.0, 0.0, 0.0], + [0.0045 * 20.0, 0.0, 0.0, 0.0, 0.0], + [0.1351 * 30.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + chemical_shift=torch.tensor( + [0.0, 3.5 * 42.5764 * 3.0, 2.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0] + ), + ) + offsets = torch.cat((torch.tensor([-300.0]), torch.linspace(-2.0, 2.0, 81))) + model = BMsimTwoPulseWASABI( + offsets=offsets, + pulse_duration=5e-3, + b1=3.7e-6, + interpulse_delay=100e-6, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + assert signal.shape == (82,) + assert signal.isfinite().all() + torch.testing.assert_close((signal / signal[0])[0], torch.tensor(1.0)) + torch.testing.assert_close(model.sequence.duration, torch.tensor(10.1e-3 + 6.5e-3), atol=1e-6, rtol=0) + + +def test_bmsim_case_1_reference() -> None: + """Compare BMsim case 1 against the Zaiss reference.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 5.0e-4]), + t1=torch.tensor([3.0, 1.05]), + t2=torch.tensor([2.0, 0.1]), + exchange_rate=torch.tensor([[0.0, 50.0], [5.0e-4 * 50.0, 0.0]]), + chemical_shift=torch.tensor([0.0, 1.9 * 42.5764 * 3.0]), + ) + model = BMsimSingleBlock( + offsets=torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))), + pulse_duration=15.0, + b1=2.0e-6, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + reference = torch.tensor(REFERENCE[1], dtype=signal.dtype) + diff = (signal.cpu() - reference).abs() + assert float(diff.max()) < 1.1e-2, f'case 1: max_abs={float(diff.max()):.6g}, mean_abs={float(diff.mean()):.6g}' + + +def test_bmsim_case_2_reference() -> None: + """Compare BMsim case 2 against the Zaiss reference.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 5.0e-4]), + t1=torch.tensor([3.0, 1.05]), + t2=torch.tensor([2.0, 0.1]), + exchange_rate=torch.tensor([[0.0, 50.0], [5.0e-4 * 50.0, 0.0]]), + chemical_shift=torch.tensor([0.0, 1.9 * 42.5764 * 3.0]), + ) + model = BMsimSingleBlock( + offsets=torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))), + pulse_duration=2.0, + b1=2.0e-6, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + reference = torch.tensor(REFERENCE[2], dtype=signal.dtype) + diff = (signal.cpu() - reference).abs() + assert float(diff.max()) < 1.9e-2, f'case 2: max_abs={float(diff.max()):.6g}, mean_abs={float(diff.mean()):.6g}' + + +def test_bmsim_case_3_reference() -> None: + """Compare BMsim case 3 with full x,y,z MT against the Zaiss reference.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 0.1351, 0.0009009, 0.0009009, 0.0045]), + t1=torch.tensor([1.0, 1.0, 1.0, 1.0, 1.3]), + t2=torch.tensor([0.040, 4.0e-5, 0.1, 0.1, 0.005]), + exchange_rate=torch.tensor( + [ + [0.0, 30.0, 50.0, 1000.0, 20.0], + [0.1351 * 30.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 50.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 1000.0, 0.0, 0.0, 0.0, 0.0], + [0.0045 * 20.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + chemical_shift=torch.tensor( + [0.0, -3.0 * 42.5764 * 3.0, 3.5 * 42.5764 * 3.0, 2.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0] + ), + ) + model = BMsimSingleBlock( + offsets=torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))), + pulse_duration=2.0, + b1=2.0e-6, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + reference = torch.tensor(REFERENCE[3], dtype=signal.dtype) + diff = (signal.cpu() - reference).abs() + assert float(diff.max()) < 9e-3, ( + f'case 3 full_mt: max_abs={float(diff.max()):.6g}, mean_abs={float(diff.mean()):.6g}' + ) + + +def test_bmsim_case_4_reference() -> None: + """Compare BMsim case 4 with full x,y,z MT against the Zaiss reference.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 0.1351, 0.0009009, 0.0009009, 0.0045]), + t1=torch.tensor([1.0, 1.0, 1.0, 1.0, 1.3]), + t2=torch.tensor([0.040, 4.0e-5, 0.1, 0.1, 0.005]), + exchange_rate=torch.tensor( + [ + [0.0, 30.0, 50.0, 1000.0, 20.0], + [0.1351 * 30.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 50.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 1000.0, 0.0, 0.0, 0.0, 0.0], + [0.0045 * 20.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + chemical_shift=torch.tensor( + [0.0, -3.0 * 42.5764 * 3.0, 3.5 * 42.5764 * 3.0, 2.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0] + ), + ) + model = BMsimSingleBlock( + offsets=torch.cat((torch.tensor([-300.0]), torch.linspace(-2.0, 2.0, 81))), + pulse_duration=5e-3, + b1=3.7e-6, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + reference = torch.tensor(REFERENCE[4], dtype=signal.dtype) + diff = (signal.cpu() - reference).abs() + assert float(diff.max()) < 3e-4, ( + f'case 4 full_mt: max_abs={float(diff.max()):.6g}, mean_abs={float(diff.mean()):.6g}' + ) + + +def test_bmsim_case_5_reference() -> None: + """Compare BMsim case 5 against the Zaiss reference.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 5.0e-4]), + t1=torch.tensor([3.0, 1.05]), + t2=torch.tensor([2.0, 0.1]), + exchange_rate=torch.tensor([[0.0, 50.0], [5.0e-4 * 50.0, 0.0]]), + chemical_shift=torch.tensor([0.0, 1.9 * 42.5764 * 3.0]), + ) + model = BMsimSingleGaussianPulse( + offsets=torch.cat((torch.tensor([-300.0]), torch.linspace(-2.0, 2.0, 201))), + pulse_envelope=PULSES[5], + dt=250e-6, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + reference = torch.tensor(REFERENCE[5], dtype=signal.dtype) + diff = (signal.cpu() - reference).abs() + assert float(diff.max()) < 4e-4, f'case 5: max_abs={float(diff.max()):.6g}, mean_abs={float(diff.mean()):.6g}' + + +def test_bmsim_case_6_reference() -> None: + """Compare BMsim case 6 against the Zaiss reference.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 5.0e-4]), + t1=torch.tensor([3.0, 1.05]), + t2=torch.tensor([2.0, 0.1]), + exchange_rate=torch.tensor([[0.0, 50.0], [5.0e-4 * 50.0, 0.0]]), + chemical_shift=torch.tensor([0.0, 1.9 * 42.5764 * 3.0]), + ) + model = BMsimGaussianPulseTrain( + offsets=torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))), + pulse_envelope=PULSES[6], + dt=250e-6, + n_pulses=36, + interpulse_delay=5e-3, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + reference = torch.tensor(REFERENCE[6], dtype=signal.dtype) + diff = (signal.cpu() - reference).abs() + assert float(diff.max()) < 4.2e-1, f'case 6: max_abs={float(diff.max()):.6g}, mean_abs={float(diff.mean()):.6g}' + + +def test_bmsim_case_7_reference() -> None: + """Compare BMsim case 7 with full x,y,z MT against the Zaiss reference.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 0.0009009, 0.0009009, 0.0045, 0.1351]), + t1=torch.tensor([1.0, 1.0, 1.0, 1.3, 1.0]), + t2=torch.tensor([0.040, 0.1, 0.1, 0.005, 4.0e-5]), + exchange_rate=torch.tensor( + [ + [0.0, 50.0, 1000.0, 20.0, 30.0], + [0.0009009 * 50.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 1000.0, 0.0, 0.0, 0.0, 0.0], + [0.0045 * 20.0, 0.0, 0.0, 0.0, 0.0], + [0.1351 * 30.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + chemical_shift=torch.tensor( + [0.0, 3.5 * 42.5764 * 3.0, 2.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0] + ), + ) + model = BMsimGaussianPulseTrain( + offsets=torch.cat((torch.tensor([-300.0]), torch.linspace(-15.0, 15.0, 301))), + pulse_envelope=PULSES[6], + dt=250e-6, + n_pulses=36, + interpulse_delay=5e-3, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + reference = torch.tensor(REFERENCE[7], dtype=signal.dtype) + diff = (signal.cpu() - reference).abs() + assert float(diff.max()) < 4e-3, ( + f'case 7 full_mt: max_abs={float(diff.max()):.6g}, mean_abs={float(diff.mean()):.6g}' + ) + + +def test_bmsim_case_8_reference() -> None: + """Compare BMsim case 8 with full x,y,z MT against the Zaiss reference.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 0.0009009, 0.0009009, 0.0045, 0.1351]), + t1=torch.tensor([1.0, 1.0, 1.0, 1.3, 1.0]), + t2=torch.tensor([0.040, 0.1, 0.1, 0.005, 4.0e-5]), + exchange_rate=torch.tensor( + [ + [0.0, 50.0, 1000.0, 20.0, 30.0], + [0.0009009 * 50.0, 0.0, 0.0, 0.0, 0.0], + [0.0009009 * 1000.0, 0.0, 0.0, 0.0, 0.0], + [0.0045 * 20.0, 0.0, 0.0, 0.0, 0.0], + [0.1351 * 30.0, 0.0, 0.0, 0.0, 0.0], + ] + ), + chemical_shift=torch.tensor( + [0.0, 3.5 * 42.5764 * 3.0, 2.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0, -3.0 * 42.5764 * 3.0] + ), + ) + model = BMsimTwoPulseWASABI( + offsets=torch.cat((torch.tensor([-300.0]), torch.linspace(-2.0, 2.0, 81))), + pulse_duration=5e-3, + b1=3.7e-6, + interpulse_delay=100e-6, + ) + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + reference = torch.tensor(REFERENCE[8], dtype=signal.dtype) + diff = (signal.cpu() - reference).abs() + assert float(diff.max()) < 1e-2, ( + f'case 8 full_mt: max_abs={float(diff.max()):.6g}, mean_abs={float(diff.mean()):.6g}' + ) + + +@pytest.mark.cuda +def test_bmsim_cuda() -> None: + """Test BMsim models work on cuda devices.""" + parameters = Parameters( + equilibrium_magnetization=torch.tensor([1.0, 5.0e-4], device='cuda'), + t1=torch.tensor([3.0, 1.05], device='cuda'), + t2=torch.tensor([2.0, 0.1], device='cuda'), + exchange_rate=torch.tensor([[0.0, 50.0], [5.0e-4 * 50.0, 0.0]], device='cuda'), + chemical_shift=torch.tensor([0.0, 1.9 * 42.5764 * 3.0], device='cuda'), + ) + model = BMsimSingleBlock(offsets=torch.tensor([-300.0, 0.0]), pulse_duration=2.0, b1=2.0e-6).cuda() + (signal,) = model( + parameters.equilibrium_magnetization, + parameters.t1, + parameters.t2, + parameters.exchange_rate, + parameters.chemical_shift, + ) + assert signal.is_cuda + assert signal.isfinite().all()